From 2f35b56cc6a6a1b1f5605a7caf2b8d84a7597cd6 Mon Sep 17 00:00:00 2001 From: Harry Pierson Date: Wed, 13 May 2026 11:37:31 -0700 Subject: [PATCH 01/27] DBOSConfig.useListenNotify --- .../spring/DBOSAutoConfiguration.java | 1 + .../dbos/transact/spring/DBOSProperties.java | 14 +++ .../dev/dbos/transact/config/DBOSConfig.java | 88 ++++++++++++++----- 3 files changed, 81 insertions(+), 22 deletions(-) diff --git a/transact-spring-boot-starter/src/main/java/dev/dbos/transact/spring/DBOSAutoConfiguration.java b/transact-spring-boot-starter/src/main/java/dev/dbos/transact/spring/DBOSAutoConfiguration.java index 068de638..935b516c 100644 --- a/transact-spring-boot-starter/src/main/java/dev/dbos/transact/spring/DBOSAutoConfiguration.java +++ b/transact-spring-boot-starter/src/main/java/dev/dbos/transact/spring/DBOSAutoConfiguration.java @@ -142,6 +142,7 @@ private DBOSConfig buildConfig(DBOSProperties props, String springAppName) { config = config.withAdminServer(props.getAdminServer().isEnabled()); config = config.withAdminServerPort(props.getAdminServer().getPort()); config = config.withMigrate(props.getDatasource().isMigrate()); + config = config.withUseListenNotify(props.getDatasource().isUseListenNotify()); config = config.withEnablePatching(props.isEnablePatching()); List listenQueues = props.getListenQueues(); diff --git a/transact-spring-boot-starter/src/main/java/dev/dbos/transact/spring/DBOSProperties.java b/transact-spring-boot-starter/src/main/java/dev/dbos/transact/spring/DBOSProperties.java index 1d2baa2b..e24e32b2 100644 --- a/transact-spring-boot-starter/src/main/java/dev/dbos/transact/spring/DBOSProperties.java +++ b/transact-spring-boot-starter/src/main/java/dev/dbos/transact/spring/DBOSProperties.java @@ -133,6 +133,12 @@ public static class Datasource { /** Whether to run database migrations on startup. Defaults to {@code true}. */ private boolean migrate = true; + /** + * Whether to use PostgreSQL LISTEN/NOTIFY for event delivery. Defaults to {@code true}. Set to + * {@code false} to use polling instead + */ + private boolean useListenNotify = true; + public String getUrl() { return url; } @@ -172,6 +178,14 @@ public boolean isMigrate() { public void setMigrate(boolean migrate) { this.migrate = migrate; } + + public boolean isUseListenNotify() { + return useListenNotify; + } + + public void setUseListenNotify(boolean useListenNotify) { + this.useListenNotify = useListenNotify; + } } public Application getApplication() { diff --git a/transact/src/main/java/dev/dbos/transact/config/DBOSConfig.java b/transact/src/main/java/dev/dbos/transact/config/DBOSConfig.java index 79b06be5..97f5200f 100644 --- a/transact/src/main/java/dev/dbos/transact/config/DBOSConfig.java +++ b/transact/src/main/java/dev/dbos/transact/config/DBOSConfig.java @@ -35,7 +35,8 @@ public record DBOSConfig( boolean enablePatching, @NonNull Set listenQueues, @Nullable DBOSSerializer serializer, - @Nullable Duration schedulerPollingInterval) { + @Nullable Duration schedulerPollingInterval, + boolean useListenNotify) { public DBOSConfig { if (appName == null || appName.isEmpty()) { @@ -80,7 +81,8 @@ public DBOSConfig(DBOSConfig other) { other.enablePatching, (other.listenQueues == null ? null : Set.copyOf(other.listenQueues)), other.serializer, - other.schedulerPollingInterval); + other.schedulerPollingInterval, + other.useListenNotify); } public static @NonNull DBOSConfig defaults(@NonNull String appName) { @@ -88,7 +90,7 @@ public DBOSConfig(DBOSConfig other) { appName, null, null, null, null, false, // adminServer 3001, // adminServerPort true, // migrate - null, null, null, null, null, null, false, null, null, null); + null, null, null, null, null, null, false, null, null, null, true); // useListenNotify } public static @NonNull DBOSConfig defaultsFromEnv(@NonNull String appName) { @@ -121,7 +123,8 @@ public DBOSConfig(DBOSConfig other) { enablePatching, listenQueues, serializer, - schedulerPollingInterval); + schedulerPollingInterval, + useListenNotify); } public @NonNull DBOSConfig withDatabaseUrl(@Nullable String v) { @@ -143,7 +146,8 @@ public DBOSConfig(DBOSConfig other) { enablePatching, listenQueues, serializer, - schedulerPollingInterval); + schedulerPollingInterval, + useListenNotify); } public @NonNull DBOSConfig withDbUser(@Nullable String v) { @@ -165,7 +169,8 @@ public DBOSConfig(DBOSConfig other) { enablePatching, listenQueues, serializer, - schedulerPollingInterval); + schedulerPollingInterval, + useListenNotify); } public @NonNull DBOSConfig withDbPassword(@Nullable String v) { @@ -187,7 +192,8 @@ public DBOSConfig(DBOSConfig other) { enablePatching, listenQueues, serializer, - schedulerPollingInterval); + schedulerPollingInterval, + useListenNotify); } public @NonNull DBOSConfig withDataSource(@Nullable DataSource v) { @@ -209,7 +215,8 @@ public DBOSConfig(DBOSConfig other) { enablePatching, listenQueues, serializer, - schedulerPollingInterval); + schedulerPollingInterval, + useListenNotify); } public @NonNull DBOSConfig withAdminServer(boolean v) { @@ -231,7 +238,8 @@ public DBOSConfig(DBOSConfig other) { enablePatching, listenQueues, serializer, - schedulerPollingInterval); + schedulerPollingInterval, + useListenNotify); } public @NonNull DBOSConfig withAdminServerPort(int v) { @@ -253,7 +261,8 @@ public DBOSConfig(DBOSConfig other) { enablePatching, listenQueues, serializer, - schedulerPollingInterval); + schedulerPollingInterval, + useListenNotify); } public @NonNull DBOSConfig withMigrate(boolean v) { @@ -275,7 +284,8 @@ public DBOSConfig(DBOSConfig other) { enablePatching, listenQueues, serializer, - schedulerPollingInterval); + schedulerPollingInterval, + useListenNotify); } public @NonNull DBOSConfig withConductorKey(@Nullable String v) { @@ -297,7 +307,8 @@ public DBOSConfig(DBOSConfig other) { enablePatching, listenQueues, serializer, - schedulerPollingInterval); + schedulerPollingInterval, + useListenNotify); } public @NonNull DBOSConfig withConductorDomain(@Nullable String v) { @@ -319,7 +330,8 @@ public DBOSConfig(DBOSConfig other) { enablePatching, listenQueues, serializer, - schedulerPollingInterval); + schedulerPollingInterval, + useListenNotify); } public @NonNull DBOSConfig withConductorExecutorMetadata(@Nullable Map v) { @@ -341,7 +353,8 @@ public DBOSConfig(DBOSConfig other) { enablePatching, listenQueues, serializer, - schedulerPollingInterval); + schedulerPollingInterval, + useListenNotify); } public @NonNull DBOSConfig withAppVersion(@Nullable String v) { @@ -363,7 +376,8 @@ public DBOSConfig(DBOSConfig other) { enablePatching, listenQueues, serializer, - schedulerPollingInterval); + schedulerPollingInterval, + useListenNotify); } public @NonNull DBOSConfig withExecutorId(@Nullable String v) { @@ -385,7 +399,8 @@ public DBOSConfig(DBOSConfig other) { enablePatching, listenQueues, serializer, - schedulerPollingInterval); + schedulerPollingInterval, + useListenNotify); } public @NonNull DBOSConfig withDatabaseSchema(@Nullable String v) { @@ -407,7 +422,8 @@ public DBOSConfig(DBOSConfig other) { enablePatching, listenQueues, serializer, - schedulerPollingInterval); + schedulerPollingInterval, + useListenNotify); } public @NonNull DBOSConfig withEnablePatching() { @@ -437,7 +453,8 @@ public DBOSConfig(DBOSConfig other) { v, listenQueues, serializer, - schedulerPollingInterval); + schedulerPollingInterval, + useListenNotify); } public @NonNull DBOSConfig enableAdminServer() { @@ -489,7 +506,8 @@ public DBOSConfig(DBOSConfig other) { enablePatching, v, serializer, - schedulerPollingInterval); + schedulerPollingInterval, + useListenNotify); } public @NonNull DBOSConfig withSerializer(@Nullable DBOSSerializer v) { @@ -511,7 +529,8 @@ public DBOSConfig(DBOSConfig other) { enablePatching, listenQueues, v, - schedulerPollingInterval); + schedulerPollingInterval, + useListenNotify); } public @NonNull DBOSConfig withSchedulerPollingInterval(@Nullable Duration v) { @@ -533,6 +552,30 @@ public DBOSConfig(DBOSConfig other) { enablePatching, listenQueues, serializer, + v, + useListenNotify); + } + + public @NonNull DBOSConfig withUseListenNotify(boolean v) { + return new DBOSConfig( + appName, + databaseUrl, + dbUser, + dbPassword, + dataSource, + adminServer, + adminServerPort, + migrate, + conductorKey, + conductorDomain, + conductorExecutorMetadata, + appVersion, + executorId, + databaseSchema, + enablePatching, + listenQueues, + serializer, + schedulerPollingInterval, v); } @@ -544,7 +587,7 @@ public String toString() { dataSource=%s, databaseSchema=%s, adminServer=%s, adminServerPort=%d, \ migrate=%s, conductorKey=%s, conductorDomain=%s, \ conductorExecutorMetadata=%s, appVersion=%s, executorId=%s, enablePatching=%s, \ - listenQueues=%s, serializer=%s, schedulerPollingInterval=%s] + listenQueues=%s, serializer=%s, schedulerPollingInterval=%s, useListenNotify=%s] """ .formatted( appName, @@ -563,6 +606,7 @@ public String toString() { enablePatching, listenQueues, serializer != null ? serializer.name() : null, - schedulerPollingInterval); + schedulerPollingInterval, + useListenNotify); } } From c74a7aad6071feda7a4451b4867dc8a0c3ef6bd1 Mon Sep 17 00:00:00 2001 From: Harry Pierson Date: Wed, 13 May 2026 13:15:10 -0700 Subject: [PATCH 02/27] use_listen_notify migration 1 --- .../transact/migrations/MigrationManager.java | 98 ++++++++++++------- .../migrations/MigrationManagerTest.java | 91 ++++++++++++++++- 2 files changed, 148 insertions(+), 41 deletions(-) diff --git a/transact/src/main/java/dev/dbos/transact/migrations/MigrationManager.java b/transact/src/main/java/dev/dbos/transact/migrations/MigrationManager.java index ce6597cd..98786cb2 100644 --- a/transact/src/main/java/dev/dbos/transact/migrations/MigrationManager.java +++ b/transact/src/main/java/dev/dbos/transact/migrations/MigrationManager.java @@ -23,29 +23,30 @@ public static void runMigrations(DBOSConfig config) { Objects.requireNonNull(config, "DBOS Config must not be null"); if (config.dataSource() != null) { - runMigrations(config.dataSource(), config.databaseSchema()); + runMigrations(config.dataSource(), config.databaseSchema(), config.useListenNotify()); } else { createDatabaseIfNotExists(config.databaseUrl(), config.dbUser(), config.dbPassword()); try (var ds = SystemDatabase.createDataSource( config.databaseUrl(), config.dbUser(), config.dbPassword())) { - runMigrations(ds, config.databaseSchema()); + runMigrations(ds, config.databaseSchema(), config.useListenNotify()); } } } - public static void runMigrations(String url, String user, String password, String schema) { + public static void runMigrations( + String url, String user, String password, String schema, boolean useListenNotify) { Objects.requireNonNull(url, "database url must not be null"); Objects.requireNonNull(user, "database user must not be null"); Objects.requireNonNull(password, "database password must not be null"); createDatabaseIfNotExists(url, user, password); try (var ds = SystemDatabase.createDataSource(url, user, password)) { - runMigrations(ds, schema); + runMigrations(ds, schema, useListenNotify); } } - private static void runMigrations(DataSource ds, String schema) { + private static void runMigrations(DataSource ds, String schema, boolean useListenNotify) { Objects.requireNonNull(ds, "Data Source must not be null"); schema = SystemDatabase.sanitizeSchema(schema); @@ -54,9 +55,23 @@ private static void runMigrations(DataSource ds, String schema) { } try (var conn = ds.getConnection()) { + + var isCockroach = false; + try (var stmt = conn.createStatement(); + var rs = stmt.executeQuery("SELECT version()")) { + if (rs.next()) { + String version = rs.getString(1).toLowerCase(); + isCockroach = version.contains("cockroachdb"); + } + } + + if (isCockroach) { + useListenNotify = false; + } + ensureDbosSchema(conn, schema); ensureMigrationTable(conn, schema); - var migrations = getMigrations(schema); + var migrations = getMigrations(schema, useListenNotify); runDbosMigrations(conn, schema, migrations); } catch (SQLException e) { throw new RuntimeException("Failed to run migrations", e); @@ -212,11 +227,12 @@ static void runDbosMigrations(Connection conn, String schema, List migra } } - public static List getMigrations(String schema) { + public static List getMigrations(String schema, boolean useListenNotify) { Objects.requireNonNull(schema); + var migrations = List.of( - MIGRATION_1, + migration1(useListenNotify), MIGRATION_2, MIGRATION_3, MIGRATION_4, @@ -238,8 +254,13 @@ public static List getMigrations(String schema) { return migrations.stream().map(m -> m.formatted(schema)).toList(); } + static String migration1(boolean useListenNotify) { + return useListenNotify ? MIGRATION_1 + MIGRATION_1_NOTIFY : MIGRATION_1; + } + static final String MIGRATION_1 = """ + -- Enable uuid extension for generating UUIDs CREATE EXTENSION IF NOT EXISTS "uuid-ossp"; CREATE TABLE "%1$s".workflow_status ( @@ -300,6 +321,38 @@ FOREIGN KEY (destination_uuid) REFERENCES "%1$s".workflow_status(workflow_uuid) ); CREATE INDEX idx_workflow_topic ON "%1$s".notifications (destination_uuid, topic); + CREATE TABLE "%1$s".workflow_events ( + workflow_uuid TEXT NOT NULL, + key TEXT NOT NULL, + value TEXT NOT NULL, + PRIMARY KEY (workflow_uuid, key), + FOREIGN KEY (workflow_uuid) REFERENCES "%1$s".workflow_status(workflow_uuid) + ON UPDATE CASCADE ON DELETE CASCADE + ); + + CREATE TABLE "%1$s".streams ( + workflow_uuid TEXT NOT NULL, + key TEXT NOT NULL, + value TEXT NOT NULL, + "offset" INT4 NOT NULL, + PRIMARY KEY (workflow_uuid, key, "offset"), + FOREIGN KEY (workflow_uuid) REFERENCES "%1$s".workflow_status(workflow_uuid) + ON UPDATE CASCADE ON DELETE CASCADE + ); + + CREATE TABLE "%1$s".event_dispatch_kv ( + service_name TEXT NOT NULL, + workflow_fn_name TEXT NOT NULL, + key TEXT NOT NULL, + value TEXT, + update_seq NUMERIC(38,0), + update_time NUMERIC(38,15), + PRIMARY KEY (service_name, workflow_fn_name, key) + ); + """; + + static final String MIGRATION_1_NOTIFY = + """ -- Create notification function CREATE OR REPLACE FUNCTION "%1$s".notifications_function() RETURNS TRIGGER AS $$ DECLARE @@ -315,15 +368,6 @@ FOREIGN KEY (destination_uuid) REFERENCES "%1$s".workflow_status(workflow_uuid) AFTER INSERT ON "%1$s".notifications FOR EACH ROW EXECUTE FUNCTION "%1$s".notifications_function(); - CREATE TABLE "%1$s".workflow_events ( - workflow_uuid TEXT NOT NULL, - key TEXT NOT NULL, - value TEXT NOT NULL, - PRIMARY KEY (workflow_uuid, key), - FOREIGN KEY (workflow_uuid) REFERENCES "%1$s".workflow_status(workflow_uuid) - ON UPDATE CASCADE ON DELETE CASCADE - ); - -- Create events function CREATE OR REPLACE FUNCTION "%1$s".workflow_events_function() RETURNS TRIGGER AS $$ DECLARE @@ -338,26 +382,6 @@ FOREIGN KEY (workflow_uuid) REFERENCES "%1$s".workflow_status(workflow_uuid) CREATE TRIGGER dbos_workflow_events_trigger AFTER INSERT ON "%1$s".workflow_events FOR EACH ROW EXECUTE FUNCTION "%1$s".workflow_events_function(); - - CREATE TABLE "%1$s".streams ( - workflow_uuid TEXT NOT NULL, - key TEXT NOT NULL, - value TEXT NOT NULL, - "offset" INT4 NOT NULL, - PRIMARY KEY (workflow_uuid, key, "offset"), - FOREIGN KEY (workflow_uuid) REFERENCES "%1$s".workflow_status(workflow_uuid) - ON UPDATE CASCADE ON DELETE CASCADE - ); - - CREATE TABLE "%1$s".event_dispatch_kv ( - service_name TEXT NOT NULL, - workflow_fn_name TEXT NOT NULL, - key TEXT NOT NULL, - value TEXT, - update_seq NUMERIC(38,0), - update_time NUMERIC(38,15), - PRIMARY KEY (service_name, workflow_fn_name, key) - ); """; static final String MIGRATION_2 = diff --git a/transact/src/test/java/dev/dbos/transact/migrations/MigrationManagerTest.java b/transact/src/test/java/dev/dbos/transact/migrations/MigrationManagerTest.java index 1520ca39..88664f84 100644 --- a/transact/src/test/java/dev/dbos/transact/migrations/MigrationManagerTest.java +++ b/transact/src/test/java/dev/dbos/transact/migrations/MigrationManagerTest.java @@ -43,6 +43,11 @@ class MigrationManagerTest { "notifications_function", "workflow_events_function", "enqueue_workflow", "send_message" }; + // Expected LISTEN/NOTIFY triggers after migrations (only when useListenNotify=true) + static final String[] EXPECTED_NOTIFY_TRIGGERS = { + "dbos_notifications_trigger", "dbos_workflow_events_trigger" + }; + @AutoClose final PgContainer pgContainer = new PgContainer(); @AutoClose HikariDataSource dataSource; @@ -69,13 +74,41 @@ void testRunMigrations_CreatesTables() throws Exception { for (String function : EXPECTED_FUNCTIONS) { assertFunctionExists(metaData, function); } + for (String trigger : EXPECTED_NOTIFY_TRIGGERS) { + assertTriggerExists(conn, trigger); + } - var migrations = new ArrayList<>(MigrationManager.getMigrations(Constants.DB_SCHEMA)); + var migrations = new ArrayList<>(MigrationManager.getMigrations(Constants.DB_SCHEMA, true)); var version = getVersion(conn); assertEquals(migrations.size(), version); } } + @Test + void testRunMigrations_NoNotify_OmitsTriggersAndFunctions() throws Exception { + var dbosConfig = pgContainer.dbosConfig().withUseListenNotify(false); + MigrationManager.runMigrations(dbosConfig); + + try (Connection conn = dataSource.getConnection()) { + DatabaseMetaData metaData = conn.getMetaData(); + + // All tables should still exist + for (String table : EXPECTED_TABLES) { + assertTableExists(metaData, table); + } + + // enqueue_workflow and send_message are unaffected + assertFunctionExists(metaData, "enqueue_workflow"); + assertFunctionExists(metaData, "send_message"); + + // LISTEN/NOTIFY functions and triggers must be absent + assertFunctionAbsent(conn, "notifications_function"); + assertFunctionAbsent(conn, "workflow_events_function"); + assertTriggerAbsent(conn, "dbos_notifications_trigger"); + assertTriggerAbsent(conn, "dbos_workflow_events_trigger"); + } + } + @ParameterizedTest @ValueSource(strings = {"invalid\"schema", "invalid'schema"}) void testRunMigrations_fails_invalid_schema(String invalidSchema) throws Exception { @@ -101,8 +134,11 @@ void testRunMigrations_customSchema(String schema) throws Exception { for (String function : EXPECTED_FUNCTIONS) { assertFunctionExists(metaData, function, schema); } + for (String trigger : EXPECTED_NOTIFY_TRIGGERS) { + assertTriggerExists(conn, trigger, schema); + } - var migrations = new ArrayList<>(MigrationManager.getMigrations(schema)); + var migrations = new ArrayList<>(MigrationManager.getMigrations(schema, true)); var version = getVersion(conn, schema); assertEquals(migrations.size(), version); } @@ -126,7 +162,7 @@ void testRunMigrations_IsIdempotent() throws Exception { void testAddingNewMigration() throws Exception { testRunMigrations_CreatesTables(); - var migrations = new ArrayList<>(MigrationManager.getMigrations(Constants.DB_SCHEMA)); + var migrations = new ArrayList<>(MigrationManager.getMigrations(Constants.DB_SCHEMA, true)); migrations.add("CREATE TABLE dummy_table(id SERIAL PRIMARY KEY);"); try (var conn = dataSource.getConnection()) { @@ -192,7 +228,7 @@ void testOriginalMigration1ThenAllMigrations_NotificationsPrimaryKey() throws Ex assertTableExists(metaData, "notifications"); // Now run all current migrations (including migration10 which ensures primary key) - var allMigrations = MigrationManager.getMigrations(Constants.DB_SCHEMA); + var allMigrations = MigrationManager.getMigrations(Constants.DB_SCHEMA, true); MigrationManager.runDbosMigrations(conn, Constants.DB_SCHEMA, allMigrations); // Verify that the notifications table has a primary key @@ -246,6 +282,53 @@ static int getVersion(Connection conn, String schema) throws Exception { } } + static void assertFunctionAbsent(Connection conn, String functionName) throws Exception { + String sql = + "SELECT 1 FROM pg_proc p JOIN pg_namespace n ON p.pronamespace = n.oid" + + " WHERE n.nspname = ? AND p.proname = ?"; + try (var ps = conn.prepareStatement(sql)) { + ps.setString(1, Constants.DB_SCHEMA); + ps.setString(2, functionName); + try (var rs = ps.executeQuery()) { + assertFalse(rs.next(), "Function %s should not exist".formatted(functionName)); + } + } + } + + static void assertTriggerExists(Connection conn, String triggerName) throws Exception { + assertTriggerExists(conn, triggerName, Constants.DB_SCHEMA); + } + + static void assertTriggerExists(Connection conn, String triggerName, String schema) + throws Exception { + schema = SystemDatabase.sanitizeSchema(schema); + String sql = + "SELECT 1 FROM pg_trigger t JOIN pg_class c ON t.tgrelid = c.oid" + + " JOIN pg_namespace n ON c.relnamespace = n.oid" + + " WHERE n.nspname = ? AND t.tgname = ?"; + try (var ps = conn.prepareStatement(sql)) { + ps.setString(1, schema); + ps.setString(2, triggerName); + try (var rs = ps.executeQuery()) { + assertTrue(rs.next(), "Trigger %s should exist in schema %s".formatted(triggerName, schema)); + } + } + } + + static void assertTriggerAbsent(Connection conn, String triggerName) throws Exception { + String sql = + "SELECT 1 FROM pg_trigger t JOIN pg_class c ON t.tgrelid = c.oid" + + " JOIN pg_namespace n ON c.relnamespace = n.oid" + + " WHERE n.nspname = ? AND t.tgname = ?"; + try (var ps = conn.prepareStatement(sql)) { + ps.setString(1, Constants.DB_SCHEMA); + ps.setString(2, triggerName); + try (var rs = ps.executeQuery()) { + assertFalse(rs.next(), "Trigger %s should not exist".formatted(triggerName)); + } + } + } + static void assertNotificationTableHasPrimaryKey( DatabaseMetaData metaData, String tableName, String schemaName) throws Exception { try (ResultSet rs = metaData.getPrimaryKeys(null, schemaName, tableName)) { From 6b1c30a8ee154dc7eb445c1d3723eb50952266ab Mon Sep 17 00:00:00 2001 From: Harry Pierson Date: Wed, 13 May 2026 14:22:55 -0700 Subject: [PATCH 03/27] signal registry --- .../transact/database/SignalRegistry.java | 26 ++ .../transact/database/SignalRegistryTest.java | 385 ++++++++++++++++++ .../migrations/MigrationManagerTest.java | 3 +- 3 files changed, 413 insertions(+), 1 deletion(-) create mode 100644 transact/src/main/java/dev/dbos/transact/database/SignalRegistry.java create mode 100644 transact/src/test/java/dev/dbos/transact/database/SignalRegistryTest.java diff --git a/transact/src/main/java/dev/dbos/transact/database/SignalRegistry.java b/transact/src/main/java/dev/dbos/transact/database/SignalRegistry.java new file mode 100644 index 00000000..ab456a6c --- /dev/null +++ b/transact/src/main/java/dev/dbos/transact/database/SignalRegistry.java @@ -0,0 +1,26 @@ +package dev.dbos.transact.database; + +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ConcurrentHashMap; + +class SignalRegistry { + + private final ConcurrentHashMap> map = new ConcurrentHashMap<>(); + + CompletableFuture subscribe(String key) { + return map.computeIfAbsent(key, k -> new CompletableFuture<>()).copy(); + } + + void signal(String key) { + CompletableFuture f = map.remove(key); + if (f != null) f.complete(null); + } + + void unsubscribe(String key) { + map.remove(key); + } + + static CompletableFuture never() { + return new CompletableFuture<>(); + } +} diff --git a/transact/src/test/java/dev/dbos/transact/database/SignalRegistryTest.java b/transact/src/test/java/dev/dbos/transact/database/SignalRegistryTest.java new file mode 100644 index 00000000..f087d574 --- /dev/null +++ b/transact/src/test/java/dev/dbos/transact/database/SignalRegistryTest.java @@ -0,0 +1,385 @@ +package dev.dbos.transact.database; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTimeoutPreemptively; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.time.Duration; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.TimeUnit; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.function.Executable; + +class SignalRegistryTest { + + SignalRegistry registry; + + @BeforeEach + void setup() { + registry = new SignalRegistry(); + } + + @Test + void testBasicSubscribeAndSignal() { + CompletableFuture f = registry.subscribe("key"); + assertFalse(f.isDone()); + registry.signal("key"); + assertTrue(f.isDone()); + assertFalse(f.isCompletedExceptionally()); + } + + @Test + void testMultipleListenersOnSameKey() { + // Multiple subscribers on the same key must all complete when signal fires + CompletableFuture f1 = registry.subscribe("key"); + CompletableFuture f2 = registry.subscribe("key"); + CompletableFuture f3 = registry.subscribe("key"); + + assertFalse(f1.isDone()); + assertFalse(f2.isDone()); + assertFalse(f3.isDone()); + + registry.signal("key"); + + assertTrue(f1.isDone()); + assertTrue(f2.isDone()); + assertTrue(f3.isDone()); + assertFalse(f1.isCompletedExceptionally()); + assertFalse(f2.isCompletedExceptionally()); + assertFalse(f3.isCompletedExceptionally()); + } + + @Test + void testMultipleSubscriptionsInAnyOf() { + // anyOf(sub1, sub2) — signalling sub2's key should complete the anyOf block + CompletableFuture f1 = registry.subscribe("key-one"); + CompletableFuture f2 = registry.subscribe("key-two"); + + CompletableFuture anyOf = CompletableFuture.anyOf(f1, f2); + assertFalse(anyOf.isDone()); + + registry.signal("key-two"); + + assertTrue(anyOf.isDone()); + assertTrue(f2.isDone()); + assertFalse(f1.isDone()); // key-one was never signalled + } + + @Test + void testSignalOnlyWakesMatchingKey() { + CompletableFuture f1 = registry.subscribe("key-one"); + CompletableFuture f2 = registry.subscribe("key-two"); + + registry.signal("key-one"); + + assertTrue(f1.isDone()); + assertFalse(f2.isDone()); + } + + @Test + void testSignalBeforeSubscribeDoesNotWake() { + // signal fires with no subscribers; subsequent subscribe gets a fresh future + registry.signal("key"); + + CompletableFuture f = registry.subscribe("key"); + assertFalse(f.isDone()); + + // verify it can still be signalled normally afterwards + registry.signal("key"); + assertTrue(f.isDone()); + } + + @Test + void testSignalIsOneShot() { + CompletableFuture f1 = registry.subscribe("key"); + registry.signal("key"); + assertTrue(f1.isDone()); + + // second signal on same key — new subscriber should need a new signal + CompletableFuture f2 = registry.subscribe("key"); + assertFalse(f2.isDone()); + } + + @Test + void testUnsubscribePreventsFutureFromBeingSignalled() throws Exception { + CompletableFuture f = registry.subscribe("key"); + registry.unsubscribe("key"); + registry.signal("key"); // no entry in map — should be a no-op + + // f was returned before unsubscribe so it's a copy of the (now-removed) shared future; + // it will never complete since nothing holds a reference to complete it + boolean completed = f.orTimeout(100, TimeUnit.MILLISECONDS).handle((v, ex) -> ex == null).get(); + assertFalse(completed); + } + + @Test + void testNeverFutureNeverCompletes() throws Exception { + CompletableFuture f = SignalRegistry.never(); + assertFalse(f.isDone()); + + boolean completed = f.orTimeout(100, TimeUnit.MILLISECONDS).handle((v, ex) -> ex == null).get(); + assertFalse(completed); + } + + @Test + void testSubscribeBeforeSignalFromAnotherThread() throws Exception { + // Subscribe on the current thread, signal from a background thread after a delay. + CompletableFuture f = registry.subscribe("foo"); + + CompletableFuture.runAsync( + () -> { + try { + Thread.sleep(100); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + registry.signal("foo"); + }); + + assertTimeoutPreemptively(Duration.ofSeconds(1), (Executable) f::get); + } + + @Test + void testSignalFromMainAfterBackgroundSubscribes() throws Exception { + // Background thread subscribes and waits; main thread signals after a delay. + CompletableFuture backgroundDone = new CompletableFuture<>(); + + CompletableFuture.runAsync( + () -> { + CompletableFuture f = registry.subscribe("foo"); + try { + f.get(500, TimeUnit.MILLISECONDS); + backgroundDone.complete(null); + } catch (Exception e) { + backgroundDone.completeExceptionally(e); + } + }); + + Thread.sleep(100); + registry.signal("foo"); + + assertTimeoutPreemptively(Duration.ofSeconds(1), (Executable) backgroundDone::get); + } + + @Test + void testMultipleSubscribersInSeparateThreads() throws Exception { + // Two background threads each subscribe; a single signal wakes both. + CompletableFuture done1 = new CompletableFuture<>(); + CompletableFuture done2 = new CompletableFuture<>(); + + CompletableFuture.runAsync( + () -> { + try { + registry.subscribe("foo").get(500, TimeUnit.MILLISECONDS); + done1.complete(null); + } catch (Exception e) { + done1.completeExceptionally(e); + } + }); + CompletableFuture.runAsync( + () -> { + try { + registry.subscribe("foo").get(500, TimeUnit.MILLISECONDS); + done2.complete(null); + } catch (Exception e) { + done2.completeExceptionally(e); + } + }); + + Thread.sleep(100); + registry.signal("foo"); + + assertTimeoutPreemptively( + Duration.ofSeconds(1), + () -> { + done1.get(); + done2.get(); + }); + } + + @Test + void testConcurrentSignalAndSubscribe() throws Exception { + // Stress: signal and subscribe racing from different threads should not deadlock or lose + // wakeups + assertTimeoutPreemptively( + Duration.ofSeconds(5), + () -> { + for (int i = 0; i < 1000; i++) { + SignalRegistry r = new SignalRegistry(); + CompletableFuture sub = r.subscribe("key"); + CompletableFuture.runAsync(() -> r.signal("key")); + sub.get(1, TimeUnit.SECONDS); + } + }); + } + + // --- anyOf determination via checking isDone + + @Test + void testCheckIsDone_notifyFires() throws Exception { + CompletableFuture onNotify = registry.subscribe("notify-key"); + CompletableFuture onCancelled = registry.subscribe("cancel-key"); + CompletableFuture onDbClosed = SignalRegistry.never(); + + registry.signal("notify-key"); + + try { + CompletableFuture.anyOf(onNotify, onCancelled, onDbClosed).get(1, TimeUnit.SECONDS); + } catch (java.util.concurrent.TimeoutException ignored) { + } + + assertFalse(onCancelled.isDone() || onDbClosed.isDone()); // routes to re-check-DB branch + assertTrue(onNotify.isDone()); + } + + @Test + void testCheckIsDone_cancelledFires() throws Exception { + CompletableFuture onNotify = registry.subscribe("notify-key"); + CompletableFuture onCancelled = registry.subscribe("cancel-key"); + CompletableFuture onDbClosed = SignalRegistry.never(); + + registry.signal("cancel-key"); + + try { + CompletableFuture.anyOf(onNotify, onCancelled, onDbClosed).get(1, TimeUnit.SECONDS); + } catch (java.util.concurrent.TimeoutException ignored) { + } + + assertTrue(onCancelled.isDone() || onDbClosed.isDone()); // routes to return-null branch + assertFalse(onNotify.isDone()); + } + + @Test + void testCheckIsDone_dbClosedFires() throws Exception { + CompletableFuture onNotify = registry.subscribe("notify-key"); + CompletableFuture onCancelled = SignalRegistry.never(); + CompletableFuture onDbClosed = new CompletableFuture<>(); + onDbClosed.complete(null); + + try { + CompletableFuture.anyOf(onNotify, onCancelled, onDbClosed).get(1, TimeUnit.SECONDS); + } catch (java.util.concurrent.TimeoutException ignored) { + } + + assertTrue(onCancelled.isDone() || onDbClosed.isDone()); // routes to return-null branch + assertFalse(onNotify.isDone()); + } + + @Test + void testCheckIsDone_timeout() throws Exception { + CompletableFuture onNotify = registry.subscribe("notify-key"); + CompletableFuture onCancelled = SignalRegistry.never(); + CompletableFuture onDbClosed = SignalRegistry.never(); + + try { + CompletableFuture.anyOf(onNotify, onCancelled, onDbClosed).get(50, TimeUnit.MILLISECONDS); + } catch (java.util.concurrent.TimeoutException ignored) { + } + + assertFalse(onCancelled.isDone() || onDbClosed.isDone()); // routes to re-check-DB branch + assertFalse(onNotify.isDone()); + } + + // --- anyOf determination via checking isDone via tagged dispatch + + @Test + void testTaggedDispatch_notifyFires() throws Exception { + CompletableFuture onNotify = registry.subscribe("notify-key"); + CompletableFuture onCancelled = registry.subscribe("cancel-key"); + CompletableFuture onDbClosed = SignalRegistry.never(); + + registry.signal("notify-key"); + + enum WakeReason { + NOTIFY, + CANCELLED, + DB_CLOSED + } + WakeReason reason = + (WakeReason) + CompletableFuture.anyOf( + onNotify.thenApply(v -> WakeReason.NOTIFY), + onCancelled.thenApply(v -> WakeReason.CANCELLED), + onDbClosed.thenApply(v -> WakeReason.DB_CLOSED)) + .get(1, TimeUnit.SECONDS); + + assertEquals(WakeReason.NOTIFY, reason); + } + + @Test + void testTaggedDispatch_cancelledFires() throws Exception { + CompletableFuture onNotify = registry.subscribe("notify-key"); + CompletableFuture onCancelled = registry.subscribe("cancel-key"); + CompletableFuture onDbClosed = SignalRegistry.never(); + + registry.signal("cancel-key"); + + enum WakeReason { + NOTIFY, + CANCELLED, + DB_CLOSED + } + WakeReason reason = + (WakeReason) + CompletableFuture.anyOf( + onNotify.thenApply(v -> WakeReason.NOTIFY), + onCancelled.thenApply(v -> WakeReason.CANCELLED), + onDbClosed.thenApply(v -> WakeReason.DB_CLOSED)) + .get(1, TimeUnit.SECONDS); + + assertEquals(WakeReason.CANCELLED, reason); + } + + @Test + void testTaggedDispatch_dbClosedFires() throws Exception { + CompletableFuture onNotify = registry.subscribe("notify-key"); + CompletableFuture onCancelled = SignalRegistry.never(); + CompletableFuture onDbClosed = new CompletableFuture<>(); + onDbClosed.complete(null); + + enum WakeReason { + NOTIFY, + CANCELLED, + DB_CLOSED + } + WakeReason reason = + (WakeReason) + CompletableFuture.anyOf( + onNotify.thenApply(v -> WakeReason.NOTIFY), + onCancelled.thenApply(v -> WakeReason.CANCELLED), + onDbClosed.thenApply(v -> WakeReason.DB_CLOSED)) + .get(1, TimeUnit.SECONDS); + + assertEquals(WakeReason.DB_CLOSED, reason); + } + + @Test + void testTaggedDispatch_timeout() throws Exception { + CompletableFuture onNotify = registry.subscribe("notify-key"); + CompletableFuture onCancelled = SignalRegistry.never(); + CompletableFuture onDbClosed = SignalRegistry.never(); + + enum WakeReason { + NOTIFY, + CANCELLED, + DB_CLOSED + } + WakeReason reason = null; + try { + reason = + (WakeReason) + CompletableFuture.anyOf( + onNotify.thenApply(v -> WakeReason.NOTIFY), + onCancelled.thenApply(v -> WakeReason.CANCELLED), + onDbClosed.thenApply(v -> WakeReason.DB_CLOSED)) + .get(50, TimeUnit.MILLISECONDS); + } catch (java.util.concurrent.TimeoutException ignored) { + } + + assertFalse(onNotify.isDone()); + assertEquals(null, reason); // anyOf threw — no wake reason was tagged + } +} diff --git a/transact/src/test/java/dev/dbos/transact/migrations/MigrationManagerTest.java b/transact/src/test/java/dev/dbos/transact/migrations/MigrationManagerTest.java index 88664f84..05c0e33e 100644 --- a/transact/src/test/java/dev/dbos/transact/migrations/MigrationManagerTest.java +++ b/transact/src/test/java/dev/dbos/transact/migrations/MigrationManagerTest.java @@ -310,7 +310,8 @@ static void assertTriggerExists(Connection conn, String triggerName, String sche ps.setString(1, schema); ps.setString(2, triggerName); try (var rs = ps.executeQuery()) { - assertTrue(rs.next(), "Trigger %s should exist in schema %s".formatted(triggerName, schema)); + assertTrue( + rs.next(), "Trigger %s should exist in schema %s".formatted(triggerName, schema)); } } } From 034115219a46031dea4aeed71a8e68dec00df9e1 Mon Sep 17 00:00:00 2001 From: Harry Pierson Date: Wed, 13 May 2026 14:49:01 -0700 Subject: [PATCH 04/27] improve signal registry --- .../transact/database/SignalRegistry.java | 51 ++++++++++++++++--- .../transact/database/SignalRegistryTest.java | 25 +++++++-- 2 files changed, 63 insertions(+), 13 deletions(-) diff --git a/transact/src/main/java/dev/dbos/transact/database/SignalRegistry.java b/transact/src/main/java/dev/dbos/transact/database/SignalRegistry.java index ab456a6c..05717cda 100644 --- a/transact/src/main/java/dev/dbos/transact/database/SignalRegistry.java +++ b/transact/src/main/java/dev/dbos/transact/database/SignalRegistry.java @@ -2,22 +2,57 @@ import java.util.concurrent.CompletableFuture; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicInteger; class SignalRegistry { - private final ConcurrentHashMap> map = new ConcurrentHashMap<>(); + private static class Entry { + final CompletableFuture future = new CompletableFuture<>(); + final AtomicInteger refs = new AtomicInteger(1); + } + + static class Subscription extends CompletableFuture implements AutoCloseable { + private final Runnable onClose; + + Subscription(Runnable onClose) { + this.onClose = onClose; + } - CompletableFuture subscribe(String key) { - return map.computeIfAbsent(key, k -> new CompletableFuture<>()).copy(); + @Override + public void close() { + onClose.run(); + } } - void signal(String key) { - CompletableFuture f = map.remove(key); - if (f != null) f.complete(null); + private final ConcurrentHashMap map = new ConcurrentHashMap<>(); + + Subscription subscribe(String key) { + Entry entry = + map.compute( + key, + (k, e) -> { + if (e != null) { + e.refs.incrementAndGet(); + return e; + } + return new Entry(); + }); + Subscription sub = + new Subscription( + () -> + map.compute( + key, + (k, e) -> { + if (e != null && e.refs.decrementAndGet() == 0) return null; + return e; + })); + entry.future.thenRun(() -> sub.complete(null)); + return sub; } - void unsubscribe(String key) { - map.remove(key); + void signal(String key) { + Entry e = map.remove(key); + if (e != null) e.future.complete(null); } static CompletableFuture never() { diff --git a/transact/src/test/java/dev/dbos/transact/database/SignalRegistryTest.java b/transact/src/test/java/dev/dbos/transact/database/SignalRegistryTest.java index f087d574..d9f3e189 100644 --- a/transact/src/test/java/dev/dbos/transact/database/SignalRegistryTest.java +++ b/transact/src/test/java/dev/dbos/transact/database/SignalRegistryTest.java @@ -104,14 +104,29 @@ void testSignalIsOneShot() { } @Test - void testUnsubscribePreventsFutureFromBeingSignalled() throws Exception { - CompletableFuture f = registry.subscribe("key"); - registry.unsubscribe("key"); + void testCloseOneSubscriberDoesNotOrphanOthers() throws Exception { + // Closing one subscription on a shared key must not prevent the remaining subscriber + // from being woken when the signal fires (ref-counting behaviour). + SignalRegistry.Subscription sub1 = registry.subscribe("key"); + SignalRegistry.Subscription sub2 = registry.subscribe("key"); + + sub1.close(); // ref count drops to 1 — key must stay in map + + registry.signal("key"); + assertTrue(sub2.isDone()); + assertFalse(sub2.isCompletedExceptionally()); + } + + @Test + void testClosePreventsFutureFromBeingSignalled() throws Exception { + SignalRegistry.Subscription sub = registry.subscribe("key"); + sub.close(); registry.signal("key"); // no entry in map — should be a no-op - // f was returned before unsubscribe so it's a copy of the (now-removed) shared future; + // sub was closed before signal so the shared future was removed from the map; // it will never complete since nothing holds a reference to complete it - boolean completed = f.orTimeout(100, TimeUnit.MILLISECONDS).handle((v, ex) -> ex == null).get(); + boolean completed = + sub.orTimeout(100, TimeUnit.MILLISECONDS).handle((v, ex) -> ex == null).get(); assertFalse(completed); } From c2c70d326e6dffd158387aa1232e5c86e987663f Mon Sep 17 00:00:00 2001 From: Harry Pierson Date: Wed, 13 May 2026 16:19:43 -0700 Subject: [PATCH 05/27] SignalKey --- .../dev/dbos/transact/cli/MigrateCommand.java | 3 +- .../transact/database/SignalRegistry.java | 25 +- .../transact/database/SignalRegistryTest.java | 261 +++++++++++------- 3 files changed, 186 insertions(+), 103 deletions(-) diff --git a/transact-cli/src/main/java/dev/dbos/transact/cli/MigrateCommand.java b/transact-cli/src/main/java/dev/dbos/transact/cli/MigrateCommand.java index 995b8a99..b5b90040 100644 --- a/transact-cli/src/main/java/dev/dbos/transact/cli/MigrateCommand.java +++ b/transact-cli/src/main/java/dev/dbos/transact/cli/MigrateCommand.java @@ -38,8 +38,9 @@ public Integer call() throws Exception { out.format(" System Database: %s\n", dbOptions.url()); out.format(" System Database User: %s\n", dbOptions.user()); + // TODO: real fix MigrationManager.runMigrations( - dbOptions.url(), dbOptions.user(), dbOptions.password(), dbOptions.schema()); + dbOptions.url(), dbOptions.user(), dbOptions.password(), dbOptions.schema(), true); grantDBOSSchemaPermissions(out, dbOptions.schema()); return 0; } diff --git a/transact/src/main/java/dev/dbos/transact/database/SignalRegistry.java b/transact/src/main/java/dev/dbos/transact/database/SignalRegistry.java index 05717cda..80037f7c 100644 --- a/transact/src/main/java/dev/dbos/transact/database/SignalRegistry.java +++ b/transact/src/main/java/dev/dbos/transact/database/SignalRegistry.java @@ -4,6 +4,17 @@ import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicInteger; +sealed interface SignalKey + permits SignalKey.Cancellation, SignalKey.Event, SignalKey.Message, SignalKey.Shutdown { + + record Cancellation(String workflowId) implements SignalKey {} + + record Event(String workflowId, String topic) implements SignalKey {} + record Message(String workflowId, String topic) implements SignalKey {} + record Shutdown() implements SignalKey {} + +} + class SignalRegistry { private static class Entry { @@ -24,9 +35,9 @@ public void close() { } } - private final ConcurrentHashMap map = new ConcurrentHashMap<>(); + private final ConcurrentHashMap map = new ConcurrentHashMap<>(); - Subscription subscribe(String key) { + public Subscription subscribe(SignalKey key) { Entry entry = map.compute( key, @@ -50,12 +61,16 @@ Subscription subscribe(String key) { return sub; } - void signal(String key) { + public void signal(SignalKey key) { Entry e = map.remove(key); if (e != null) e.future.complete(null); } - static CompletableFuture never() { - return new CompletableFuture<>(); + Iterable keys() { + return map.keySet(); + } + + static Subscription never() { + return new Subscription(() -> {}); } } diff --git a/transact/src/test/java/dev/dbos/transact/database/SignalRegistryTest.java b/transact/src/test/java/dev/dbos/transact/database/SignalRegistryTest.java index d9f3e189..5f83a82d 100644 --- a/transact/src/test/java/dev/dbos/transact/database/SignalRegistryTest.java +++ b/transact/src/test/java/dev/dbos/transact/database/SignalRegistryTest.java @@ -2,6 +2,7 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotEquals; import static org.junit.jupiter.api.Assertions.assertTimeoutPreemptively; import static org.junit.jupiter.api.Assertions.assertTrue; @@ -17,32 +18,52 @@ class SignalRegistryTest { SignalRegistry registry; + // Reusable keys + static final SignalKey KEY = new SignalKey.Cancellation("wf-1"); + static final SignalKey KEY_A = new SignalKey.Cancellation("wf-a"); + static final SignalKey KEY_B = new SignalKey.Cancellation("wf-b"); + static final SignalKey FOO = new SignalKey.Cancellation("foo"); + @BeforeEach void setup() { registry = new SignalRegistry(); } + // --- SignalKey structural equality --- + + @Test + void testSignalKeyStructuralEquality() { + assertEquals(new SignalKey.Cancellation("wf-1"), new SignalKey.Cancellation("wf-1")); + assertEquals(new SignalKey.Event("wf-1", "t"), new SignalKey.Event("wf-1", "t")); + assertNotEquals(new SignalKey.Cancellation("wf-1"), new SignalKey.Cancellation("wf-2")); + // Same fields, different types — must not be equal (prevents cross-type collisions) + assertNotEquals( + (SignalKey) new SignalKey.Cancellation("wf-1"), + (SignalKey) new SignalKey.Event("wf-1", "wf-1")); + } + + // --- Core subscribe / signal behaviour --- + @Test void testBasicSubscribeAndSignal() { - CompletableFuture f = registry.subscribe("key"); + CompletableFuture f = registry.subscribe(KEY); assertFalse(f.isDone()); - registry.signal("key"); + registry.signal(KEY); assertTrue(f.isDone()); assertFalse(f.isCompletedExceptionally()); } @Test void testMultipleListenersOnSameKey() { - // Multiple subscribers on the same key must all complete when signal fires - CompletableFuture f1 = registry.subscribe("key"); - CompletableFuture f2 = registry.subscribe("key"); - CompletableFuture f3 = registry.subscribe("key"); + CompletableFuture f1 = registry.subscribe(KEY); + CompletableFuture f2 = registry.subscribe(KEY); + CompletableFuture f3 = registry.subscribe(KEY); assertFalse(f1.isDone()); assertFalse(f2.isDone()); assertFalse(f3.isDone()); - registry.signal("key"); + registry.signal(KEY); assertTrue(f1.isDone()); assertTrue(f2.isDone()); @@ -54,77 +75,88 @@ void testMultipleListenersOnSameKey() { @Test void testMultipleSubscriptionsInAnyOf() { - // anyOf(sub1, sub2) — signalling sub2's key should complete the anyOf block - CompletableFuture f1 = registry.subscribe("key-one"); - CompletableFuture f2 = registry.subscribe("key-two"); + CompletableFuture f1 = registry.subscribe(KEY_A); + CompletableFuture f2 = registry.subscribe(KEY_B); CompletableFuture anyOf = CompletableFuture.anyOf(f1, f2); assertFalse(anyOf.isDone()); - registry.signal("key-two"); + registry.signal(KEY_B); assertTrue(anyOf.isDone()); assertTrue(f2.isDone()); - assertFalse(f1.isDone()); // key-one was never signalled + assertFalse(f1.isDone()); } @Test void testSignalOnlyWakesMatchingKey() { - CompletableFuture f1 = registry.subscribe("key-one"); - CompletableFuture f2 = registry.subscribe("key-two"); + CompletableFuture f1 = registry.subscribe(KEY_A); + CompletableFuture f2 = registry.subscribe(KEY_B); - registry.signal("key-one"); + registry.signal(KEY_A); assertTrue(f1.isDone()); assertFalse(f2.isDone()); } + @Test + void testDifferentKeyTypesWithSameFieldsDoNotCollide() { + // Event("wf-1", "wf-1") and Cancellation("wf-1") must occupy separate map entries + SignalKey eventKey = new SignalKey.Event("wf-1", "wf-1"); + SignalKey cancellationKey = new SignalKey.Cancellation("wf-1"); + + CompletableFuture f1 = registry.subscribe(eventKey); + CompletableFuture f2 = registry.subscribe(cancellationKey); + + registry.signal(cancellationKey); + + assertTrue(f2.isDone()); + assertFalse(f1.isDone()); + } + @Test void testSignalBeforeSubscribeDoesNotWake() { - // signal fires with no subscribers; subsequent subscribe gets a fresh future - registry.signal("key"); + registry.signal(KEY); - CompletableFuture f = registry.subscribe("key"); + CompletableFuture f = registry.subscribe(KEY); assertFalse(f.isDone()); - // verify it can still be signalled normally afterwards - registry.signal("key"); + registry.signal(KEY); assertTrue(f.isDone()); } @Test void testSignalIsOneShot() { - CompletableFuture f1 = registry.subscribe("key"); - registry.signal("key"); + CompletableFuture f1 = registry.subscribe(KEY); + registry.signal(KEY); assertTrue(f1.isDone()); - // second signal on same key — new subscriber should need a new signal - CompletableFuture f2 = registry.subscribe("key"); + CompletableFuture f2 = registry.subscribe(KEY); assertFalse(f2.isDone()); } + // --- Subscription / close --- + @Test void testCloseOneSubscriberDoesNotOrphanOthers() throws Exception { // Closing one subscription on a shared key must not prevent the remaining subscriber // from being woken when the signal fires (ref-counting behaviour). - SignalRegistry.Subscription sub1 = registry.subscribe("key"); - SignalRegistry.Subscription sub2 = registry.subscribe("key"); + SignalRegistry.Subscription sub1 = registry.subscribe(KEY); + SignalRegistry.Subscription sub2 = registry.subscribe(KEY); sub1.close(); // ref count drops to 1 — key must stay in map - registry.signal("key"); + registry.signal(KEY); assertTrue(sub2.isDone()); assertFalse(sub2.isCompletedExceptionally()); } @Test void testClosePreventsFutureFromBeingSignalled() throws Exception { - SignalRegistry.Subscription sub = registry.subscribe("key"); + SignalRegistry.Subscription sub = registry.subscribe(KEY); sub.close(); - registry.signal("key"); // no entry in map — should be a no-op + registry.signal(KEY); // no entry in map — should be a no-op - // sub was closed before signal so the shared future was removed from the map; - // it will never complete since nothing holds a reference to complete it boolean completed = sub.orTimeout(100, TimeUnit.MILLISECONDS).handle((v, ex) -> ex == null).get(); assertFalse(completed); @@ -132,17 +164,18 @@ void testClosePreventsFutureFromBeingSignalled() throws Exception { @Test void testNeverFutureNeverCompletes() throws Exception { - CompletableFuture f = SignalRegistry.never(); + SignalRegistry.Subscription f = SignalRegistry.never(); assertFalse(f.isDone()); boolean completed = f.orTimeout(100, TimeUnit.MILLISECONDS).handle((v, ex) -> ex == null).get(); assertFalse(completed); } + // --- Threading --- + @Test void testSubscribeBeforeSignalFromAnotherThread() throws Exception { - // Subscribe on the current thread, signal from a background thread after a delay. - CompletableFuture f = registry.subscribe("foo"); + CompletableFuture f = registry.subscribe(FOO); CompletableFuture.runAsync( () -> { @@ -151,7 +184,7 @@ void testSubscribeBeforeSignalFromAnotherThread() throws Exception { } catch (InterruptedException e) { Thread.currentThread().interrupt(); } - registry.signal("foo"); + registry.signal(FOO); }); assertTimeoutPreemptively(Duration.ofSeconds(1), (Executable) f::get); @@ -159,12 +192,11 @@ void testSubscribeBeforeSignalFromAnotherThread() throws Exception { @Test void testSignalFromMainAfterBackgroundSubscribes() throws Exception { - // Background thread subscribes and waits; main thread signals after a delay. CompletableFuture backgroundDone = new CompletableFuture<>(); CompletableFuture.runAsync( () -> { - CompletableFuture f = registry.subscribe("foo"); + CompletableFuture f = registry.subscribe(FOO); try { f.get(500, TimeUnit.MILLISECONDS); backgroundDone.complete(null); @@ -174,21 +206,20 @@ void testSignalFromMainAfterBackgroundSubscribes() throws Exception { }); Thread.sleep(100); - registry.signal("foo"); + registry.signal(FOO); assertTimeoutPreemptively(Duration.ofSeconds(1), (Executable) backgroundDone::get); } @Test void testMultipleSubscribersInSeparateThreads() throws Exception { - // Two background threads each subscribe; a single signal wakes both. CompletableFuture done1 = new CompletableFuture<>(); CompletableFuture done2 = new CompletableFuture<>(); CompletableFuture.runAsync( () -> { try { - registry.subscribe("foo").get(500, TimeUnit.MILLISECONDS); + registry.subscribe(FOO).get(500, TimeUnit.MILLISECONDS); done1.complete(null); } catch (Exception e) { done1.completeExceptionally(e); @@ -197,7 +228,7 @@ void testMultipleSubscribersInSeparateThreads() throws Exception { CompletableFuture.runAsync( () -> { try { - registry.subscribe("foo").get(500, TimeUnit.MILLISECONDS); + registry.subscribe(FOO).get(500, TimeUnit.MILLISECONDS); done2.complete(null); } catch (Exception e) { done2.completeExceptionally(e); @@ -205,7 +236,7 @@ void testMultipleSubscribersInSeparateThreads() throws Exception { }); Thread.sleep(100); - registry.signal("foo"); + registry.signal(FOO); assertTimeoutPreemptively( Duration.ofSeconds(1), @@ -217,61 +248,107 @@ void testMultipleSubscribersInSeparateThreads() throws Exception { @Test void testConcurrentSignalAndSubscribe() throws Exception { - // Stress: signal and subscribe racing from different threads should not deadlock or lose - // wakeups assertTimeoutPreemptively( Duration.ofSeconds(5), () -> { for (int i = 0; i < 1000; i++) { SignalRegistry r = new SignalRegistry(); - CompletableFuture sub = r.subscribe("key"); - CompletableFuture.runAsync(() -> r.signal("key")); + CompletableFuture sub = r.subscribe(KEY); + CompletableFuture.runAsync(() -> r.signal(KEY)); sub.get(1, TimeUnit.SECONDS); } }); } - // --- anyOf determination via checking isDone + // --- keys() --- + + @Test + void testKeysReflectsActiveSubscriptions() { + registry.subscribe(KEY_A); + registry.subscribe(KEY_B); + + Iterable keys = registry.keys(); + assertTrue(iterableContains(keys, KEY_A)); + assertTrue(iterableContains(keys, KEY_B)); + } + + @Test + void testKeysExcludesSignalledKey() { + registry.subscribe(KEY_A); + registry.subscribe(KEY_B); + registry.signal(KEY_A); + + Iterable keys = registry.keys(); + assertFalse(iterableContains(keys, KEY_A)); + assertTrue(iterableContains(keys, KEY_B)); + } + + @Test + void testKeysEmptyWhenNoSubscribers() { + assertFalse(registry.keys().iterator().hasNext()); + } + + @Test + void testKeysExcludesKeyAfterLastSubscriberCloses() { + SignalRegistry.Subscription sub = registry.subscribe(KEY); + sub.close(); + assertFalse(registry.keys().iterator().hasNext()); + } + + private static boolean iterableContains(Iterable keys, SignalKey target) { + for (SignalKey k : keys) { + if (k.equals(target)) return true; + } + return false; + } + + // --- anyOf determination via isDone (Option A) --- @Test void testCheckIsDone_notifyFires() throws Exception { - CompletableFuture onNotify = registry.subscribe("notify-key"); - CompletableFuture onCancelled = registry.subscribe("cancel-key"); - CompletableFuture onDbClosed = SignalRegistry.never(); + SignalKey notifyKey = new SignalKey.Event("wf-1", "topic"); + SignalKey cancelKey = new SignalKey.Cancellation("wf-1"); + + CompletableFuture onNotify = registry.subscribe(notifyKey); + CompletableFuture onCancelled = registry.subscribe(cancelKey); + CompletableFuture onDbClosed = SignalRegistry.never(); - registry.signal("notify-key"); + registry.signal(notifyKey); try { CompletableFuture.anyOf(onNotify, onCancelled, onDbClosed).get(1, TimeUnit.SECONDS); } catch (java.util.concurrent.TimeoutException ignored) { } - assertFalse(onCancelled.isDone() || onDbClosed.isDone()); // routes to re-check-DB branch + assertFalse(onCancelled.isDone() || onDbClosed.isDone()); assertTrue(onNotify.isDone()); } @Test void testCheckIsDone_cancelledFires() throws Exception { - CompletableFuture onNotify = registry.subscribe("notify-key"); - CompletableFuture onCancelled = registry.subscribe("cancel-key"); - CompletableFuture onDbClosed = SignalRegistry.never(); + SignalKey notifyKey = new SignalKey.Event("wf-1", "topic"); + SignalKey cancelKey = new SignalKey.Cancellation("wf-1"); - registry.signal("cancel-key"); + CompletableFuture onNotify = registry.subscribe(notifyKey); + CompletableFuture onCancelled = registry.subscribe(cancelKey); + CompletableFuture onDbClosed = SignalRegistry.never(); + + registry.signal(cancelKey); try { CompletableFuture.anyOf(onNotify, onCancelled, onDbClosed).get(1, TimeUnit.SECONDS); } catch (java.util.concurrent.TimeoutException ignored) { } - assertTrue(onCancelled.isDone() || onDbClosed.isDone()); // routes to return-null branch + assertTrue(onCancelled.isDone() || onDbClosed.isDone()); assertFalse(onNotify.isDone()); } @Test void testCheckIsDone_dbClosedFires() throws Exception { - CompletableFuture onNotify = registry.subscribe("notify-key"); + CompletableFuture onNotify = registry.subscribe(new SignalKey.Event("wf-1", "topic")); CompletableFuture onCancelled = SignalRegistry.never(); - CompletableFuture onDbClosed = new CompletableFuture<>(); + CompletableFuture onDbClosed = new CompletableFuture<>(); onDbClosed.complete(null); try { @@ -279,40 +356,39 @@ void testCheckIsDone_dbClosedFires() throws Exception { } catch (java.util.concurrent.TimeoutException ignored) { } - assertTrue(onCancelled.isDone() || onDbClosed.isDone()); // routes to return-null branch + assertTrue(onCancelled.isDone() || onDbClosed.isDone()); assertFalse(onNotify.isDone()); } @Test void testCheckIsDone_timeout() throws Exception { - CompletableFuture onNotify = registry.subscribe("notify-key"); + CompletableFuture onNotify = registry.subscribe(new SignalKey.Event("wf-1", "topic")); CompletableFuture onCancelled = SignalRegistry.never(); - CompletableFuture onDbClosed = SignalRegistry.never(); + CompletableFuture onDbClosed = SignalRegistry.never(); try { CompletableFuture.anyOf(onNotify, onCancelled, onDbClosed).get(50, TimeUnit.MILLISECONDS); } catch (java.util.concurrent.TimeoutException ignored) { } - assertFalse(onCancelled.isDone() || onDbClosed.isDone()); // routes to re-check-DB branch + assertFalse(onCancelled.isDone() || onDbClosed.isDone()); assertFalse(onNotify.isDone()); } - // --- anyOf determination via checking isDone via tagged dispatch + // --- anyOf determination via tagged dispatch (Option B) --- @Test void testTaggedDispatch_notifyFires() throws Exception { - CompletableFuture onNotify = registry.subscribe("notify-key"); - CompletableFuture onCancelled = registry.subscribe("cancel-key"); - CompletableFuture onDbClosed = SignalRegistry.never(); + SignalKey notifyKey = new SignalKey.Event("wf-1", "topic"); + SignalKey cancelKey = new SignalKey.Cancellation("wf-1"); - registry.signal("notify-key"); + CompletableFuture onNotify = registry.subscribe(notifyKey); + CompletableFuture onCancelled = registry.subscribe(cancelKey); + CompletableFuture onDbClosed = SignalRegistry.never(); - enum WakeReason { - NOTIFY, - CANCELLED, - DB_CLOSED - } + registry.signal(notifyKey); + + enum WakeReason { NOTIFY, CANCELLED, DB_CLOSED } WakeReason reason = (WakeReason) CompletableFuture.anyOf( @@ -326,17 +402,16 @@ enum WakeReason { @Test void testTaggedDispatch_cancelledFires() throws Exception { - CompletableFuture onNotify = registry.subscribe("notify-key"); - CompletableFuture onCancelled = registry.subscribe("cancel-key"); - CompletableFuture onDbClosed = SignalRegistry.never(); + SignalKey notifyKey = new SignalKey.Event("wf-1", "topic"); + SignalKey cancelKey = new SignalKey.Cancellation("wf-1"); - registry.signal("cancel-key"); + CompletableFuture onNotify = registry.subscribe(notifyKey); + CompletableFuture onCancelled = registry.subscribe(cancelKey); + CompletableFuture onDbClosed = SignalRegistry.never(); - enum WakeReason { - NOTIFY, - CANCELLED, - DB_CLOSED - } + registry.signal(cancelKey); + + enum WakeReason { NOTIFY, CANCELLED, DB_CLOSED } WakeReason reason = (WakeReason) CompletableFuture.anyOf( @@ -350,16 +425,12 @@ enum WakeReason { @Test void testTaggedDispatch_dbClosedFires() throws Exception { - CompletableFuture onNotify = registry.subscribe("notify-key"); + CompletableFuture onNotify = registry.subscribe(new SignalKey.Event("wf-1", "topic")); CompletableFuture onCancelled = SignalRegistry.never(); - CompletableFuture onDbClosed = new CompletableFuture<>(); + CompletableFuture onDbClosed = new CompletableFuture<>(); onDbClosed.complete(null); - enum WakeReason { - NOTIFY, - CANCELLED, - DB_CLOSED - } + enum WakeReason { NOTIFY, CANCELLED, DB_CLOSED } WakeReason reason = (WakeReason) CompletableFuture.anyOf( @@ -373,15 +444,11 @@ enum WakeReason { @Test void testTaggedDispatch_timeout() throws Exception { - CompletableFuture onNotify = registry.subscribe("notify-key"); + CompletableFuture onNotify = registry.subscribe(new SignalKey.Event("wf-1", "topic")); CompletableFuture onCancelled = SignalRegistry.never(); - CompletableFuture onDbClosed = SignalRegistry.never(); + CompletableFuture onDbClosed = SignalRegistry.never(); - enum WakeReason { - NOTIFY, - CANCELLED, - DB_CLOSED - } + enum WakeReason { NOTIFY, CANCELLED, DB_CLOSED } WakeReason reason = null; try { reason = @@ -395,6 +462,6 @@ enum WakeReason { } assertFalse(onNotify.isDone()); - assertEquals(null, reason); // anyOf threw — no wake reason was tagged + assertEquals(null, reason); } } From 7e2fc40f76dc91d5d1ff9f26291c13d46d0f9d7a Mon Sep 17 00:00:00 2001 From: Harry Pierson Date: Wed, 13 May 2026 16:20:06 -0700 Subject: [PATCH 06/27] spotless --- .../transact/database/SignalRegistry.java | 5 +- .../transact/database/SignalRegistryTest.java | 70 ++++++++++++------- 2 files changed, 46 insertions(+), 29 deletions(-) diff --git a/transact/src/main/java/dev/dbos/transact/database/SignalRegistry.java b/transact/src/main/java/dev/dbos/transact/database/SignalRegistry.java index 80037f7c..b035771a 100644 --- a/transact/src/main/java/dev/dbos/transact/database/SignalRegistry.java +++ b/transact/src/main/java/dev/dbos/transact/database/SignalRegistry.java @@ -5,14 +5,15 @@ import java.util.concurrent.atomic.AtomicInteger; sealed interface SignalKey - permits SignalKey.Cancellation, SignalKey.Event, SignalKey.Message, SignalKey.Shutdown { + permits SignalKey.Cancellation, SignalKey.Event, SignalKey.Message, SignalKey.Shutdown { record Cancellation(String workflowId) implements SignalKey {} record Event(String workflowId, String topic) implements SignalKey {} + record Message(String workflowId, String topic) implements SignalKey {} - record Shutdown() implements SignalKey {} + record Shutdown() implements SignalKey {} } class SignalRegistry { diff --git a/transact/src/test/java/dev/dbos/transact/database/SignalRegistryTest.java b/transact/src/test/java/dev/dbos/transact/database/SignalRegistryTest.java index 5f83a82d..9a50dbb2 100644 --- a/transact/src/test/java/dev/dbos/transact/database/SignalRegistryTest.java +++ b/transact/src/test/java/dev/dbos/transact/database/SignalRegistryTest.java @@ -19,10 +19,10 @@ class SignalRegistryTest { SignalRegistry registry; // Reusable keys - static final SignalKey KEY = new SignalKey.Cancellation("wf-1"); + static final SignalKey KEY = new SignalKey.Cancellation("wf-1"); static final SignalKey KEY_A = new SignalKey.Cancellation("wf-a"); static final SignalKey KEY_B = new SignalKey.Cancellation("wf-b"); - static final SignalKey FOO = new SignalKey.Cancellation("foo"); + static final SignalKey FOO = new SignalKey.Cancellation("foo"); @BeforeEach void setup() { @@ -102,7 +102,7 @@ void testSignalOnlyWakesMatchingKey() { @Test void testDifferentKeyTypesWithSameFieldsDoNotCollide() { // Event("wf-1", "wf-1") and Cancellation("wf-1") must occupy separate map entries - SignalKey eventKey = new SignalKey.Event("wf-1", "wf-1"); + SignalKey eventKey = new SignalKey.Event("wf-1", "wf-1"); SignalKey cancellationKey = new SignalKey.Cancellation("wf-1"); CompletableFuture f1 = registry.subscribe(eventKey); @@ -306,12 +306,12 @@ private static boolean iterableContains(Iterable keys, SignalKey targ @Test void testCheckIsDone_notifyFires() throws Exception { - SignalKey notifyKey = new SignalKey.Event("wf-1", "topic"); - SignalKey cancelKey = new SignalKey.Cancellation("wf-1"); + SignalKey notifyKey = new SignalKey.Event("wf-1", "topic"); + SignalKey cancelKey = new SignalKey.Cancellation("wf-1"); - CompletableFuture onNotify = registry.subscribe(notifyKey); + CompletableFuture onNotify = registry.subscribe(notifyKey); CompletableFuture onCancelled = registry.subscribe(cancelKey); - CompletableFuture onDbClosed = SignalRegistry.never(); + CompletableFuture onDbClosed = SignalRegistry.never(); registry.signal(notifyKey); @@ -326,12 +326,12 @@ void testCheckIsDone_notifyFires() throws Exception { @Test void testCheckIsDone_cancelledFires() throws Exception { - SignalKey notifyKey = new SignalKey.Event("wf-1", "topic"); - SignalKey cancelKey = new SignalKey.Cancellation("wf-1"); + SignalKey notifyKey = new SignalKey.Event("wf-1", "topic"); + SignalKey cancelKey = new SignalKey.Cancellation("wf-1"); - CompletableFuture onNotify = registry.subscribe(notifyKey); + CompletableFuture onNotify = registry.subscribe(notifyKey); CompletableFuture onCancelled = registry.subscribe(cancelKey); - CompletableFuture onDbClosed = SignalRegistry.never(); + CompletableFuture onDbClosed = SignalRegistry.never(); registry.signal(cancelKey); @@ -346,9 +346,9 @@ void testCheckIsDone_cancelledFires() throws Exception { @Test void testCheckIsDone_dbClosedFires() throws Exception { - CompletableFuture onNotify = registry.subscribe(new SignalKey.Event("wf-1", "topic")); + CompletableFuture onNotify = registry.subscribe(new SignalKey.Event("wf-1", "topic")); CompletableFuture onCancelled = SignalRegistry.never(); - CompletableFuture onDbClosed = new CompletableFuture<>(); + CompletableFuture onDbClosed = new CompletableFuture<>(); onDbClosed.complete(null); try { @@ -362,9 +362,9 @@ void testCheckIsDone_dbClosedFires() throws Exception { @Test void testCheckIsDone_timeout() throws Exception { - CompletableFuture onNotify = registry.subscribe(new SignalKey.Event("wf-1", "topic")); + CompletableFuture onNotify = registry.subscribe(new SignalKey.Event("wf-1", "topic")); CompletableFuture onCancelled = SignalRegistry.never(); - CompletableFuture onDbClosed = SignalRegistry.never(); + CompletableFuture onDbClosed = SignalRegistry.never(); try { CompletableFuture.anyOf(onNotify, onCancelled, onDbClosed).get(50, TimeUnit.MILLISECONDS); @@ -382,13 +382,17 @@ void testTaggedDispatch_notifyFires() throws Exception { SignalKey notifyKey = new SignalKey.Event("wf-1", "topic"); SignalKey cancelKey = new SignalKey.Cancellation("wf-1"); - CompletableFuture onNotify = registry.subscribe(notifyKey); + CompletableFuture onNotify = registry.subscribe(notifyKey); CompletableFuture onCancelled = registry.subscribe(cancelKey); - CompletableFuture onDbClosed = SignalRegistry.never(); + CompletableFuture onDbClosed = SignalRegistry.never(); registry.signal(notifyKey); - enum WakeReason { NOTIFY, CANCELLED, DB_CLOSED } + enum WakeReason { + NOTIFY, + CANCELLED, + DB_CLOSED + } WakeReason reason = (WakeReason) CompletableFuture.anyOf( @@ -405,13 +409,17 @@ void testTaggedDispatch_cancelledFires() throws Exception { SignalKey notifyKey = new SignalKey.Event("wf-1", "topic"); SignalKey cancelKey = new SignalKey.Cancellation("wf-1"); - CompletableFuture onNotify = registry.subscribe(notifyKey); + CompletableFuture onNotify = registry.subscribe(notifyKey); CompletableFuture onCancelled = registry.subscribe(cancelKey); - CompletableFuture onDbClosed = SignalRegistry.never(); + CompletableFuture onDbClosed = SignalRegistry.never(); registry.signal(cancelKey); - enum WakeReason { NOTIFY, CANCELLED, DB_CLOSED } + enum WakeReason { + NOTIFY, + CANCELLED, + DB_CLOSED + } WakeReason reason = (WakeReason) CompletableFuture.anyOf( @@ -425,12 +433,16 @@ enum WakeReason { NOTIFY, CANCELLED, DB_CLOSED } @Test void testTaggedDispatch_dbClosedFires() throws Exception { - CompletableFuture onNotify = registry.subscribe(new SignalKey.Event("wf-1", "topic")); + CompletableFuture onNotify = registry.subscribe(new SignalKey.Event("wf-1", "topic")); CompletableFuture onCancelled = SignalRegistry.never(); - CompletableFuture onDbClosed = new CompletableFuture<>(); + CompletableFuture onDbClosed = new CompletableFuture<>(); onDbClosed.complete(null); - enum WakeReason { NOTIFY, CANCELLED, DB_CLOSED } + enum WakeReason { + NOTIFY, + CANCELLED, + DB_CLOSED + } WakeReason reason = (WakeReason) CompletableFuture.anyOf( @@ -444,11 +456,15 @@ enum WakeReason { NOTIFY, CANCELLED, DB_CLOSED } @Test void testTaggedDispatch_timeout() throws Exception { - CompletableFuture onNotify = registry.subscribe(new SignalKey.Event("wf-1", "topic")); + CompletableFuture onNotify = registry.subscribe(new SignalKey.Event("wf-1", "topic")); CompletableFuture onCancelled = SignalRegistry.never(); - CompletableFuture onDbClosed = SignalRegistry.never(); + CompletableFuture onDbClosed = SignalRegistry.never(); - enum WakeReason { NOTIFY, CANCELLED, DB_CLOSED } + enum WakeReason { + NOTIFY, + CANCELLED, + DB_CLOSED + } WakeReason reason = null; try { reason = From 122a0877cecb636436d3b430e3e0f7896c3c5345 Mon Sep 17 00:00:00 2001 From: Harry Pierson Date: Thu, 14 May 2026 15:28:09 -0700 Subject: [PATCH 07/27] DbContext --- .../database/ApplicationVersionDAO.java | 30 +- .../dev/dbos/transact/database/DbContext.java | 21 ++ .../transact/database/ExternalStateDAO.java | 16 +- .../transact/database/NotificationsDAO.java | 102 +++--- .../dev/dbos/transact/database/QueuesDAO.java | 33 +- .../dbos/transact/database/SchedulesDAO.java | 82 ++--- .../dev/dbos/transact/database/StepsDAO.java | 126 +++----- .../dbos/transact/database/StreamsDAO.java | 42 ++- .../transact/database/SystemDatabase.java | 198 ++++-------- .../dbos/transact/database/WorkflowDAO.java | 297 ++++++++---------- 10 files changed, 396 insertions(+), 551 deletions(-) create mode 100644 transact/src/main/java/dev/dbos/transact/database/DbContext.java diff --git a/transact/src/main/java/dev/dbos/transact/database/ApplicationVersionDAO.java b/transact/src/main/java/dev/dbos/transact/database/ApplicationVersionDAO.java index 0ccd7402..4ea44d52 100644 --- a/transact/src/main/java/dev/dbos/transact/database/ApplicationVersionDAO.java +++ b/transact/src/main/java/dev/dbos/transact/database/ApplicationVersionDAO.java @@ -8,22 +8,19 @@ import java.util.List; import java.util.UUID; -import javax.sql.DataSource; - class ApplicationVersionDAO { private ApplicationVersionDAO() {} - static void createApplicationVersion(DataSource dataSource, String schema, String versionName) - throws SQLException { + static void createApplicationVersion(DbContext ctx, String versionName) throws SQLException { String sql = """ INSERT INTO "%s".application_versions (version_id, version_name) VALUES (?, ?) ON CONFLICT (version_name) DO NOTHING """ - .formatted(schema); - try (var conn = dataSource.getConnection(); + .formatted(ctx.schema()); + try (var conn = ctx.getConnection(); var stmt = conn.prepareStatement(sql)) { stmt.setString(1, UUID.randomUUID().toString()); stmt.setString(2, versionName); @@ -32,16 +29,15 @@ ON CONFLICT (version_name) DO NOTHING } static void updateApplicationVersionTimestamp( - DataSource dataSource, String schema, String versionName, Instant newTimestamp) - throws SQLException { + DbContext ctx, String versionName, Instant newTimestamp) throws SQLException { String sql = """ UPDATE "%s".application_versions SET version_timestamp = ? WHERE version_name = ? """ - .formatted(schema); - try (var conn = dataSource.getConnection(); + .formatted(ctx.schema()); + try (var conn = ctx.getConnection(); var stmt = conn.prepareStatement(sql)) { stmt.setLong(1, newTimestamp.toEpochMilli()); stmt.setString(2, versionName); @@ -49,17 +45,16 @@ static void updateApplicationVersionTimestamp( } } - static List listApplicationVersions(DataSource dataSource, String schema) - throws SQLException { + static List listApplicationVersions(DbContext ctx) throws SQLException { String sql = """ SELECT version_id, version_name, version_timestamp, created_at FROM "%s".application_versions ORDER BY version_timestamp DESC """ - .formatted(schema); + .formatted(ctx.schema()); List results = new ArrayList<>(); - try (var conn = dataSource.getConnection(); + try (var conn = ctx.getConnection(); var stmt = conn.prepareStatement(sql); var rs = stmt.executeQuery()) { while (rs.next()) { @@ -74,8 +69,7 @@ static List listApplicationVersions(DataSource dataSource, String s return results; } - static VersionInfo getLatestApplicationVersion(DataSource dataSource, String schema) - throws SQLException { + static VersionInfo getLatestApplicationVersion(DbContext ctx) throws SQLException { String sql = """ SELECT version_id, version_name, version_timestamp, created_at @@ -83,8 +77,8 @@ static VersionInfo getLatestApplicationVersion(DataSource dataSource, String sch ORDER BY version_timestamp DESC LIMIT 1 """ - .formatted(schema); - try (var conn = dataSource.getConnection(); + .formatted(ctx.schema()); + try (var conn = ctx.getConnection(); var stmt = conn.prepareStatement(sql); var rs = stmt.executeQuery()) { if (rs.next()) { diff --git a/transact/src/main/java/dev/dbos/transact/database/DbContext.java b/transact/src/main/java/dev/dbos/transact/database/DbContext.java new file mode 100644 index 00000000..7bc024b2 --- /dev/null +++ b/transact/src/main/java/dev/dbos/transact/database/DbContext.java @@ -0,0 +1,21 @@ +package dev.dbos.transact.database; + +import dev.dbos.transact.json.DBOSSerializer; + +import java.sql.Connection; +import java.sql.SQLException; +import java.util.function.BooleanSupplier; + +import javax.sql.DataSource; + +record DbContext( + DataSource dataSource, String schema, DBOSSerializer serializer, BooleanSupplier closed) { + + Connection getConnection() throws SQLException { + return dataSource.getConnection(); + } + + boolean isClosed() { + return closed.getAsBoolean(); + } +} diff --git a/transact/src/main/java/dev/dbos/transact/database/ExternalStateDAO.java b/transact/src/main/java/dev/dbos/transact/database/ExternalStateDAO.java index ef395d3d..d231a31f 100644 --- a/transact/src/main/java/dev/dbos/transact/database/ExternalStateDAO.java +++ b/transact/src/main/java/dev/dbos/transact/database/ExternalStateDAO.java @@ -6,22 +6,19 @@ import java.util.Objects; import java.util.Optional; -import javax.sql.DataSource; - class ExternalStateDAO { private ExternalStateDAO() {} static Optional getExternalState( - DataSource dataSource, String schema, String service, String workflowName, String key) - throws SQLException { + DbContext ctx, String service, String workflowName, String key) throws SQLException { final String sql = """ SELECT value, update_seq, update_time FROM "%s".event_dispatch_kv WHERE service_name = ? AND workflow_fn_name = ? AND key = ? """ - .formatted(schema); + .formatted(ctx.schema()); - try (var conn = dataSource.getConnection(); + try (var conn = ctx.getConnection(); var stmt = conn.prepareStatement(sql)) { stmt.setString(1, Objects.requireNonNull(service, "service must not be null")); stmt.setString(2, Objects.requireNonNull(workflowName, "workflowName must not be null")); @@ -41,8 +38,7 @@ static Optional getExternalState( } } - static ExternalState upsertExternalState( - DataSource dataSource, String schema, ExternalState state) throws SQLException { + static ExternalState upsertExternalState(DbContext ctx, ExternalState state) throws SQLException { final var sql = """ INSERT INTO "%s".event_dispatch_kv ( @@ -58,9 +54,9 @@ ON CONFLICT (service_name, workflow_fn_name, key) ) THEN EXCLUDED.value ELSE event_dispatch_kv.value END RETURNING value, update_time, update_seq """ - .formatted(schema); + .formatted(ctx.schema()); - try (var conn = dataSource.getConnection(); + try (var conn = ctx.getConnection(); var stmt = conn.prepareStatement(sql)) { stmt.setString(1, Objects.requireNonNull(state.service(), "service must not be null")); stmt.setString( diff --git a/transact/src/main/java/dev/dbos/transact/database/NotificationsDAO.java b/transact/src/main/java/dev/dbos/transact/database/NotificationsDAO.java index 239511cc..3bee63fe 100644 --- a/transact/src/main/java/dev/dbos/transact/database/NotificationsDAO.java +++ b/transact/src/main/java/dev/dbos/transact/database/NotificationsDAO.java @@ -20,8 +20,6 @@ import java.util.UUID; import java.util.concurrent.TimeUnit; -import javax.sql.DataSource; - import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -32,9 +30,7 @@ private NotificationsDAO() {} private static final Logger logger = LoggerFactory.getLogger(NotificationsDAO.class); static void send( - DataSource dataSource, - String schema, - DBOSSerializer serializer, + DbContext ctx, String workflowId, int stepId, String destinationId, @@ -44,16 +40,17 @@ static void send( String serialization) throws SQLException { + DBOSSerializer serializer = ctx.serializer(); var startTime = System.currentTimeMillis(); String functionName = "DBOS.send"; String finalTopic = (topic != null) ? topic : Constants.DBOS_NULL_TOPIC; - try (Connection conn = dataSource.getConnection()) { + try (Connection conn = ctx.getConnection()) { conn.setAutoCommit(false); try { StepResult recordedOutput = - StepsDAO.checkStepExecutionTxn(workflowId, stepId, functionName, conn, schema); + StepsDAO.checkStepExecutionTxn(conn, ctx.schema(), workflowId, stepId, functionName); if (recordedOutput != null) { logger.debug( @@ -81,7 +78,7 @@ static void send( VALUES (?, ?, ?, ?, ?) ON CONFLICT (message_uuid) DO NOTHING """ - .formatted(schema); + .formatted(ctx.schema()); try (PreparedStatement stmt = conn.prepareStatement(sql)) { stmt.setString(1, destinationId); @@ -98,7 +95,8 @@ ON CONFLICT (message_uuid) DO NOTHING } var output = new StepResult(workflowId, stepId, functionName, null, null, null, null); - StepsDAO.recordStepResultTxn(output, startTime, System.currentTimeMillis(), conn, schema); + StepsDAO.recordStepResultTxn( + conn, ctx.schema(), output, startTime, System.currentTimeMillis()); conn.commit(); @@ -114,15 +112,14 @@ ON CONFLICT (message_uuid) DO NOTHING } static void sendDirect( - DataSource dataSource, - String schema, - DBOSSerializer serializer, + DbContext ctx, String destinationId, Object message, String topic, String messageId, String serialization) throws SQLException { + DBOSSerializer serializer = ctx.serializer(); String finalTopic = (topic != null) ? topic : Constants.DBOS_NULL_TOPIC; String finalMessageId = (messageId != null) ? messageId : UUID.randomUUID().toString(); var serializedMsg = SerializationUtil.serializeValue(message, serialization, serializer); @@ -134,9 +131,9 @@ static void sendDirect( VALUES (?, ?, ?, ?, ?) ON CONFLICT (message_uuid) DO NOTHING """ - .formatted(schema); + .formatted(ctx.schema()); - try (var conn = dataSource.getConnection(); + try (var conn = ctx.getConnection(); var stmt = conn.prepareStatement(sql)) { stmt.setString(1, destinationId); stmt.setString(2, finalTopic); @@ -153,9 +150,7 @@ ON CONFLICT (message_uuid) DO NOTHING } static Object recv( - DataSource dataSource, - String schema, - DBOSSerializer serializer, + DbContext ctx, NotificationService notificationService, Duration dbPollingInterval, String workflowId, @@ -165,13 +160,15 @@ static Object recv( Duration timeout) throws SQLException { + DBOSSerializer serializer = ctx.serializer(); var startTime = System.currentTimeMillis(); String functionName = "DBOS.recv"; String finalTopic = (topic != null) ? topic : Constants.DBOS_NULL_TOPIC; StepResult recordedOutput; - try (Connection c = dataSource.getConnection()) { - recordedOutput = StepsDAO.checkStepExecutionTxn(workflowId, stepId, functionName, c, schema); + try (Connection c = ctx.getConnection()) { + recordedOutput = + StepsDAO.checkStepExecutionTxn(c, ctx.schema(), workflowId, stepId, functionName); } if (recordedOutput != null) { @@ -201,14 +198,15 @@ static Object recv( } while (true) { + if (ctx.isClosed()) throw new IllegalStateException("SystemDatabase is closed"); boolean hasExistingNotification; - try (Connection conn = dataSource.getConnection()) { + try (Connection conn = ctx.getConnection()) { final String sql = """ SELECT topic FROM "%s".notifications WHERE destination_uuid = ? AND topic = ? AND consumed = FALSE """ - .formatted(schema); + .formatted(ctx.schema()); try (PreparedStatement stmt = conn.prepareStatement(sql)) { stmt.setString(1, workflowId); @@ -225,9 +223,7 @@ static Object recv( if (!checkedDBForSleep) { actualTimeout = - StepsDAO.durableSleepDuration( - dataSource, workflowId, timeoutFunctionId, timeout, schema, serializer) - .toMillis(); + StepsDAO.durableSleepDuration(ctx, workflowId, timeoutFunctionId, timeout).toMillis(); checkedDBForSleep = true; targetTime = nowTime + actualTimeout; } @@ -246,7 +242,7 @@ static Object recv( notificationService.unregisterNotificationCondition(payload); } - try (Connection conn = dataSource.getConnection()) { + try (Connection conn = ctx.getConnection()) { conn.setAutoCommit(false); try { @@ -267,7 +263,7 @@ static Object recv( ) RETURNING message, serialization """ - .formatted(schema); + .formatted(ctx.schema()); String serializedMessage = null; String serialization = null; @@ -291,7 +287,8 @@ static Object recv( StepResult output = new StepResult( workflowId, stepId, functionName, serializedMessage, null, null, serialization); - StepsDAO.recordStepResultTxn(output, startTime, System.currentTimeMillis(), conn, schema); + StepsDAO.recordStepResultTxn( + conn, ctx.schema(), output, startTime, System.currentTimeMillis()); conn.commit(); return recvdMessage; @@ -321,7 +318,7 @@ ON CONFLICT (workflow_uuid, key) """ .formatted(schema); - try (PreparedStatement stmt = conn.prepareStatement(eventSql)) { + try (var stmt = conn.prepareStatement(eventSql)) { stmt.setString(1, workflowId); stmt.setString(2, key); stmt.setString(3, message); @@ -338,7 +335,7 @@ ON CONFLICT (workflow_uuid, key, function_id) """ .formatted(schema); - try (PreparedStatement stmt = conn.prepareStatement(eventHistorySql)) { + try (var stmt = conn.prepareStatement(eventHistorySql)) { stmt.setString(1, workflowId); stmt.setInt(2, functionId); stmt.setString(3, key); @@ -349,9 +346,7 @@ ON CONFLICT (workflow_uuid, key, function_id) } static void setEvent( - DataSource dataSource, - String schema, - DBOSSerializer serializer, + DbContext ctx, String workflowId, int functionId, String key, @@ -360,18 +355,20 @@ static void setEvent( String serialization) throws SQLException { + DBOSSerializer serializer = ctx.serializer(); var startTime = System.currentTimeMillis(); String functionName = "DBOS.setEvent"; SerializationUtil.SerializedResult serializedResult = SerializationUtil.serializeValue(message, serialization, serializer); - try (Connection conn = dataSource.getConnection()) { + try (var conn = ctx.getConnection()) { conn.setAutoCommit(false); try { if (asStep) { var recordedOutput = - StepsDAO.checkStepExecutionTxn(workflowId, functionId, functionName, conn, schema); + StepsDAO.checkStepExecutionTxn( + conn, ctx.schema(), workflowId, functionId, functionName); if (recordedOutput != null) { logger.debug( "Replaying setEvent, workflow: {}, step: {}, key: {}", workflowId, functionId, key); @@ -385,7 +382,7 @@ static void setEvent( setEvent( conn, - schema, + ctx.schema(), workflowId, functionId, key, @@ -395,7 +392,8 @@ static void setEvent( if (asStep) { StepResult output = new StepResult(workflowId, functionId, functionName, null, null, null, null); - StepsDAO.recordStepResultTxn(output, startTime, System.currentTimeMillis(), conn, schema); + StepsDAO.recordStepResultTxn( + conn, ctx.schema(), output, startTime, System.currentTimeMillis()); } conn.commit(); @@ -409,9 +407,7 @@ static void setEvent( } static Object getEvent( - DataSource dataSource, - String schema, - DBOSSerializer serializer, + DbContext ctx, NotificationService notificationService, Duration dbPollingInterval, String targetUuid, @@ -420,15 +416,16 @@ static Object getEvent( GetWorkflowEventContext callerCtx) throws SQLException { + DBOSSerializer serializer = ctx.serializer(); var startTime = System.currentTimeMillis(); String functionName = "DBOS.getEvent"; if (callerCtx != null) { StepResult recordedOutput; - try (Connection conn = dataSource.getConnection()) { + try (Connection conn = ctx.getConnection()) { recordedOutput = StepsDAO.checkStepExecutionTxn( - callerCtx.workflowId(), callerCtx.functionId(), functionName, conn, schema); + conn, ctx.schema(), callerCtx.workflowId(), callerCtx.functionId(), functionName); } if (recordedOutput != null) { @@ -455,7 +452,7 @@ static Object getEvent( """ SELECT value, serialization FROM "%s".workflow_events WHERE workflow_uuid = ? AND key = ? """ - .formatted(schema); + .formatted(ctx.schema()); double actualTimeout = Objects.requireNonNull(timeout, "getEvent timeout cannot be null").toMillis(); @@ -464,7 +461,8 @@ static Object getEvent( var hasExistingNotification = false; while (true) { - try (Connection conn = dataSource.getConnection(); + if (ctx.isClosed()) throw new IllegalStateException("SystemDatabase is closed"); + try (Connection conn = ctx.getConnection(); PreparedStatement stmt = conn.prepareStatement(sql)) { stmt.setString(1, targetUuid); @@ -488,12 +486,7 @@ static Object getEvent( if (callerCtx != null && !checkedDBForSleep) { actualTimeout = StepsDAO.durableSleepDuration( - dataSource, - callerCtx.workflowId(), - callerCtx.timeoutFunctionId(), - timeout, - schema, - serializer) + ctx, callerCtx.workflowId(), callerCtx.timeoutFunctionId(), timeout) .toMillis(); targetTime = System.currentTimeMillis() + actualTimeout; checkedDBForSleep = true; @@ -523,8 +516,7 @@ static Object getEvent( null, toSaveSer.serialization()) .withOutput(toSaveSer.serializedValue()); - StepsDAO.recordStepResultTxn( - dataSource, output, startTime, System.currentTimeMillis(), schema); + StepsDAO.recordStepResultTxn(ctx, output, startTime, System.currentTimeMillis()); } return value; @@ -535,9 +527,9 @@ static Object getEvent( } } - static List getAllNotifications( - DataSource dataSource, String schema, DBOSSerializer serializer, String workflowId) + static List getAllNotifications(DbContext ctx, String workflowId) throws SQLException { + DBOSSerializer serializer = ctx.serializer(); var sql = """ SELECT topic, message, serialization, created_at_epoch_ms, consumed @@ -545,10 +537,10 @@ static List getAllNotifications( WHERE destination_uuid = ? ORDER BY created_at_epoch_ms """ - .formatted(schema); + .formatted(ctx.schema()); var notifications = new ArrayList(); - try (var conn = dataSource.getConnection(); + try (var conn = ctx.getConnection(); var stmt = conn.prepareStatement(sql)) { stmt.setString(1, workflowId); try (var rs = stmt.executeQuery()) { diff --git a/transact/src/main/java/dev/dbos/transact/database/QueuesDAO.java b/transact/src/main/java/dev/dbos/transact/database/QueuesDAO.java index c0adfc48..654b82d2 100644 --- a/transact/src/main/java/dev/dbos/transact/database/QueuesDAO.java +++ b/transact/src/main/java/dev/dbos/transact/database/QueuesDAO.java @@ -14,8 +14,6 @@ import java.util.List; import java.util.Map; -import javax.sql.DataSource; - import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -26,19 +24,14 @@ private QueuesDAO() {} private static final Logger logger = LoggerFactory.getLogger(QueuesDAO.class); static List getAndStartQueuedWorkflows( - DataSource dataSource, - String schema, - Queue queue, - String executorId, - String appVersion, - String partitionKey) + DbContext ctx, Queue queue, String executorId, String appVersion, String partitionKey) throws SQLException { if (partitionKey != null && partitionKey.length() == 0) { partitionKey = null; } - try (Connection connection = dataSource.getConnection()) { + try (Connection connection = ctx.getConnection()) { connection.setAutoCommit(false); try (Statement stmt = connection.createStatement()) { @@ -59,7 +52,7 @@ SELECT COUNT(*) AND status NOT IN (?, ?) AND started_at_epoch_ms > ? """ - .formatted(schema); + .formatted(ctx.schema()); if (partitionKey != null) { limiterQuery += " AND queue_partition_key = ?"; } @@ -94,7 +87,7 @@ SELECT executor_id, COUNT(*) as task_count FROM "%s".workflow_status WHERE queue_name = ? AND status = ? """ - .formatted(schema); + .formatted(ctx.schema()); if (partitionKey != null) { pendingQuery += " AND queue_partition_key = ?"; } @@ -163,7 +156,7 @@ SELECT executor_id, COUNT(*) as task_count AND status = ? AND (application_version = ? OR application_version IS NULL) """ - .formatted(schema); + .formatted(ctx.schema()); if (partitionKey != null) { query += " AND queue_partition_key = ?"; } @@ -221,7 +214,7 @@ THEN EXTRACT(epoch FROM NOW()) * 1000 + workflow_timeout_ms END WHERE workflow_uuid = ? """ - .formatted(schema); + .formatted(ctx.schema()); try (var ps = connection.prepareStatement(updateQuery)) { for (var id : dequeuedWorkflowIds) { @@ -251,8 +244,7 @@ THEN EXTRACT(epoch FROM NOW()) * 1000 + workflow_timeout_ms } } - static boolean clearQueueAssignment(DataSource dataSource, String schema, String workflowId) - throws SQLException { + static boolean clearQueueAssignment(DbContext ctx, String workflowId) throws SQLException { final String sql = """ @@ -260,8 +252,8 @@ static boolean clearQueueAssignment(DataSource dataSource, String schema, String SET started_at_epoch_ms = NULL, status = ? WHERE workflow_uuid = ? AND queue_name IS NOT NULL AND status = ? """ - .formatted(schema); - try (Connection connection = dataSource.getConnection(); + .formatted(ctx.schema()); + try (Connection connection = ctx.getConnection(); PreparedStatement stmt = connection.prepareStatement(sql)) { stmt.setString(1, WorkflowState.ENQUEUED.name()); stmt.setString(2, workflowId); @@ -272,8 +264,7 @@ static boolean clearQueueAssignment(DataSource dataSource, String schema, String } } - static List getQueuePartitions(DataSource dataSource, String schema, String queueName) - throws SQLException { + static List getQueuePartitions(DbContext ctx, String queueName) throws SQLException { final String sql = """ @@ -283,9 +274,9 @@ static List getQueuePartitions(DataSource dataSource, String schema, Str AND status = ? AND queue_partition_key IS NOT NULL """ - .formatted(schema); + .formatted(ctx.schema()); - try (Connection connection = dataSource.getConnection(); + try (Connection connection = ctx.getConnection(); PreparedStatement stmt = connection.prepareStatement(sql)) { stmt.setString(1, queueName); stmt.setString(2, WorkflowState.ENQUEUED.name()); diff --git a/transact/src/main/java/dev/dbos/transact/database/SchedulesDAO.java b/transact/src/main/java/dev/dbos/transact/database/SchedulesDAO.java index 8e2f0715..3619c5f4 100644 --- a/transact/src/main/java/dev/dbos/transact/database/SchedulesDAO.java +++ b/transact/src/main/java/dev/dbos/transact/database/SchedulesDAO.java @@ -20,17 +20,13 @@ import java.util.StringJoiner; import java.util.UUID; -import javax.sql.DataSource; - class SchedulesDAO { private SchedulesDAO() {} - static void createSchedule( - DataSource dataSource, String schema, DBOSSerializer serializer, WorkflowSchedule schedule) - throws SQLException { - try (Connection conn = dataSource.getConnection()) { - createSchedule(conn, schema, serializer, schedule); + static void createSchedule(DbContext ctx, WorkflowSchedule schedule) throws SQLException { + try (Connection conn = ctx.getConnection()) { + createSchedule(conn, ctx.schema(), ctx.serializer(), schedule); } } @@ -85,14 +81,13 @@ static void createSchedule( } static List listSchedules( - DataSource dataSource, - String schema, - DBOSSerializer serializer, + DbContext ctx, List statuses, List workflowNames, List scheduleNamePrefixes) throws SQLException { + DBOSSerializer serializer = ctx.serializer(); StringBuilder sql = new StringBuilder( """ @@ -102,7 +97,7 @@ static List listSchedules( FROM "%s".workflow_schedules WHERE TRUE """ - .formatted(schema)); + .formatted(ctx.schema())); List params = new ArrayList<>(); @@ -124,7 +119,7 @@ static List listSchedules( sql.append(orClauses).append(")"); } - try (Connection conn = dataSource.getConnection(); + try (Connection conn = ctx.getConnection(); PreparedStatement ps = conn.prepareStatement(sql.toString())) { List arrays = new ArrayList<>(); try { @@ -153,9 +148,8 @@ static List listSchedules( } } - static Optional getSchedule( - DataSource dataSource, String schema, DBOSSerializer serializer, String name) - throws SQLException { + static Optional getSchedule(DbContext ctx, String name) throws SQLException { + DBOSSerializer serializer = ctx.serializer(); String sql = """ SELECT schedule_id, schedule_name, workflow_name, workflow_class_name, @@ -164,9 +158,9 @@ static Optional getSchedule( FROM "%s".workflow_schedules WHERE schedule_name = ? """ - .formatted(schema); + .formatted(ctx.schema()); - try (Connection conn = dataSource.getConnection(); + try (Connection conn = ctx.getConnection(); PreparedStatement ps = conn.prepareStatement(sql)) { ps.setString(1, name); try (ResultSet rs = ps.executeQuery()) { @@ -178,25 +172,23 @@ static Optional getSchedule( } } - static void pauseSchedule(DataSource dataSource, String schema, String name) throws SQLException { - setScheduleStatus(dataSource, schema, name, ScheduleStatus.PAUSED); + static void pauseSchedule(DbContext ctx, String name) throws SQLException { + setScheduleStatus(ctx, name, ScheduleStatus.PAUSED); } - static void resumeSchedule(DataSource dataSource, String schema, String name) - throws SQLException { - setScheduleStatus(dataSource, schema, name, ScheduleStatus.ACTIVE); + static void resumeSchedule(DbContext ctx, String name) throws SQLException { + setScheduleStatus(ctx, name, ScheduleStatus.ACTIVE); } - private static void setScheduleStatus( - DataSource dataSource, String schema, String name, ScheduleStatus status) + private static void setScheduleStatus(DbContext ctx, String name, ScheduleStatus status) throws SQLException { String sql = """ UPDATE "%s".workflow_schedules SET status = ? WHERE schedule_name = ? """ - .formatted(schema); + .formatted(ctx.schema()); - try (Connection conn = dataSource.getConnection(); + try (Connection conn = ctx.getConnection(); PreparedStatement ps = conn.prepareStatement(sql)) { ps.setString(1, status.name()); ps.setString(2, name); @@ -204,15 +196,15 @@ private static void setScheduleStatus( } } - static void updateScheduleLastFiredAt( - DataSource dataSource, String schema, String name, Instant lastFiredAt) throws SQLException { + static void updateScheduleLastFiredAt(DbContext ctx, String name, Instant lastFiredAt) + throws SQLException { String sql = """ UPDATE "%s".workflow_schedules SET last_fired_at = ? WHERE schedule_name = ? """ - .formatted(schema); + .formatted(ctx.schema()); - try (Connection conn = dataSource.getConnection(); + try (Connection conn = ctx.getConnection(); PreparedStatement ps = conn.prepareStatement(sql)) { ps.setString(1, lastFiredAt != null ? lastFiredAt.toString() : null); ps.setString(2, name); @@ -220,10 +212,9 @@ static void updateScheduleLastFiredAt( } } - static void deleteSchedule(DataSource dataSource, String schema, String name) - throws SQLException { - try (Connection conn = dataSource.getConnection()) { - deleteSchedule(conn, schema, name); + static void deleteSchedule(DbContext ctx, String name) throws SQLException { + try (var conn = ctx.getConnection()) { + deleteSchedule(conn, ctx.schema(), name); } } @@ -234,27 +225,22 @@ static void deleteSchedule(Connection conn, String schema, String name) throws S """ .formatted(schema); - try (PreparedStatement ps = conn.prepareStatement(sql)) { - ps.setString(1, name); - ps.executeUpdate(); + try (var stmt = conn.prepareStatement(sql)) { + stmt.setString(1, name); + stmt.executeUpdate(); } } - static void applySchedules( - DataSource dataSource, - String schema, - DBOSSerializer serializer, - List schedules) - throws SQLException { - try (Connection conn = dataSource.getConnection()) { + static void applySchedules(DbContext ctx, List schedules) throws SQLException { + try (var conn = ctx.getConnection()) { conn.setAutoCommit(false); try { for (WorkflowSchedule schedule : schedules) { - deleteSchedule(conn, schema, schedule.scheduleName()); + deleteSchedule(conn, ctx.schema(), schedule.scheduleName()); createSchedule( conn, - schema, - serializer, + ctx.schema(), + ctx.serializer(), schedule .withScheduleId(UUID.randomUUID().toString()) .withStatus(ScheduleStatus.ACTIVE) @@ -264,8 +250,6 @@ static void applySchedules( } catch (SQLException e) { conn.rollback(); throw e; - } finally { - conn.setAutoCommit(true); } } } diff --git a/transact/src/main/java/dev/dbos/transact/database/StepsDAO.java b/transact/src/main/java/dev/dbos/transact/database/StepsDAO.java index f6cc2052..ab82dfe4 100644 --- a/transact/src/main/java/dev/dbos/transact/database/StepsDAO.java +++ b/transact/src/main/java/dev/dbos/transact/database/StepsDAO.java @@ -16,8 +16,6 @@ import java.util.List; import java.util.Objects; -import javax.sql.DataSource; - import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -28,24 +26,16 @@ private StepsDAO() {} private static final Logger logger = LoggerFactory.getLogger(StepsDAO.class); static void recordStepResultTxn( - DataSource dataSource, - StepResult result, - long startTimeEpochMs, - long endTimeEpochMs, - String schema) + DbContext ctx, StepResult result, long startTimeEpochMs, long endTimeEpochMs) throws SQLException { - try (Connection connection = dataSource.getConnection()) { - recordStepResultTxn(result, startTimeEpochMs, endTimeEpochMs, connection, schema); + try (var conn = ctx.getConnection()) { + recordStepResultTxn(conn, ctx.schema(), result, startTimeEpochMs, endTimeEpochMs); } DebugTriggers.debugTriggerPoint(DebugTriggers.DEBUG_TRIGGER_STEP_COMMIT); } static void recordStepResultTxn( - StepResult result, - Long startTimeEpochMs, - Long endTimeEpochMs, - Connection connection, - String schema) + Connection conn, String schema, StepResult result, Long startTimeEpochMs, Long endTimeEpochMs) throws SQLException { Objects.requireNonNull(schema); @@ -58,33 +48,33 @@ static void recordStepResultTxn( """ .formatted(schema); - try (PreparedStatement pstmt = connection.prepareStatement(sql)) { - pstmt.setString(1, result.workflowId()); - pstmt.setInt(2, result.stepId()); - pstmt.setString(3, result.stepName()); + try (var stmt = conn.prepareStatement(sql)) { + stmt.setString(1, result.workflowId()); + stmt.setInt(2, result.stepId()); + stmt.setString(3, result.stepName()); if (result.output() != null) { - pstmt.setString(4, result.output()); + stmt.setString(4, result.output()); } else { - pstmt.setNull(4, Types.LONGVARCHAR); + stmt.setNull(4, Types.LONGVARCHAR); } if (result.error() != null) { - pstmt.setString(5, result.error()); + stmt.setString(5, result.error()); } else { - pstmt.setNull(5, Types.LONGVARCHAR); + stmt.setNull(5, Types.LONGVARCHAR); } if (result.childWorkflowId() != null) { - pstmt.setString(6, result.childWorkflowId()); + stmt.setString(6, result.childWorkflowId()); } else { - pstmt.setNull(6, Types.VARCHAR); + stmt.setNull(6, Types.VARCHAR); } - pstmt.setObject(7, startTimeEpochMs); - pstmt.setObject(8, endTimeEpochMs); + stmt.setObject(7, startTimeEpochMs); + stmt.setObject(8, endTimeEpochMs); - try (ResultSet rs = pstmt.executeQuery()) { + try (ResultSet rs = stmt.executeQuery()) { if (rs.next() && endTimeEpochMs != null) { long completedAt = rs.getLong("completed_at_epoch_ms"); if (completedAt != endTimeEpochMs) { @@ -107,7 +97,7 @@ static void recordStepResultTxn( } static StepResult checkStepExecutionTxn( - String workflowId, int functionId, String functionName, Connection connection, String schema) + Connection conn, String schema, String workflowId, int functionId, String functionName) throws SQLException, DBOSWorkflowCancelledException, DBOSUnexpectedStepException { Objects.requireNonNull(schema); @@ -118,7 +108,7 @@ static StepResult checkStepExecutionTxn( .formatted(schema); String workflowStatus = null; - try (PreparedStatement pstmt = connection.prepareStatement(sql)) { + try (var pstmt = conn.prepareStatement(sql)) { pstmt.setString(1, workflowId); try (ResultSet rs = pstmt.executeQuery()) { if (rs.next()) { @@ -147,7 +137,7 @@ static StepResult checkStepExecutionTxn( StepResult recordedResult = null; String recordedFunctionName = null; - try (PreparedStatement pstmt = connection.prepareStatement(operationOutputSql)) { + try (var pstmt = conn.prepareStatement(operationOutputSql)) { pstmt.setString(1, workflowId); pstmt.setInt(2, functionId); try (ResultSet rs = pstmt.executeQuery()) { @@ -176,28 +166,22 @@ static StepResult checkStepExecutionTxn( } static List listWorkflowSteps( - DataSource dataSource, - String workflowId, - Boolean loadOutput, - Integer limit, - Integer offset, - String schema, - DBOSSerializer serializer) + DbContext ctx, String workflowId, Boolean loadOutput, Integer limit, Integer offset) throws SQLException { - try (Connection connection = dataSource.getConnection()) { + try (var conn = ctx.getConnection()) { return listWorkflowSteps( - connection, workflowId, loadOutput, limit, offset, schema, serializer); + conn, ctx.schema(), ctx.serializer(), workflowId, loadOutput, limit, offset); } } static List listWorkflowSteps( - Connection connection, + Connection conn, + String schema, + DBOSSerializer serializer, String workflowId, Boolean loadOutput, Integer limit, - Integer offset, - String schema, - DBOSSerializer serializer) + Integer offset) throws SQLException { StringBuilder sqlBuilder = @@ -221,7 +205,7 @@ static List listWorkflowSteps( List steps = new ArrayList<>(); - try (PreparedStatement stmt = connection.prepareStatement(sql)) { + try (var stmt = conn.prepareStatement(sql)) { int paramIndex = 1; stmt.setString(paramIndex++, workflowId); @@ -277,16 +261,9 @@ static List listWorkflowSteps( return steps; } - static void sleep( - DataSource dataSource, - String workflowUuid, - int functionId, - Duration duration, - String schema, - DBOSSerializer serializer) + static void sleep(DbContext ctx, String workflowUuid, int functionId, Duration duration) throws SQLException { - var sleepDuration = - durableSleepDuration(dataSource, workflowUuid, functionId, duration, schema, serializer); + var sleepDuration = durableSleepDuration(ctx, workflowUuid, functionId, duration); logger.debug("Sleeping for duration {}", sleepDuration); try { Thread.sleep(sleepDuration.toMillis()); @@ -296,7 +273,7 @@ static void sleep( } } - static String getCheckpointName(Connection conn, String workflowId, int functionId, String schema) + static String getCheckpointName(Connection conn, String schema, String workflowId, int functionId) throws SQLException { var sql = """ @@ -306,10 +283,10 @@ static String getCheckpointName(Connection conn, String workflowId, int function """ .formatted(schema); - try (var ps = conn.prepareStatement(sql)) { - ps.setString(1, workflowId); - ps.setInt(2, functionId); - try (var rs = ps.executeQuery()) { + try (var stmt = conn.prepareStatement(sql)) { + stmt.setString(1, workflowId); + stmt.setInt(2, functionId); + try (var rs = stmt.executeQuery()) { if (rs.next()) { return rs.getString("function_name"); } else { @@ -319,15 +296,14 @@ static String getCheckpointName(Connection conn, String workflowId, int function } } - static boolean patch( - DataSource dataSource, String workflowId, int functionId, String patchName, String schema) + static boolean patch(DbContext ctx, String workflowId, int functionId, String patchName) throws SQLException { Objects.requireNonNull(patchName, "patchName cannot be null"); - try (Connection conn = dataSource.getConnection()) { - var checkpointName = getCheckpointName(conn, workflowId, functionId, schema); + try (var conn = ctx.getConnection()) { + var checkpointName = getCheckpointName(conn, ctx.schema(), workflowId, functionId); if (checkpointName == null) { var output = new StepResult(workflowId, functionId, patchName, null, null, null, null); - recordStepResultTxn(output, System.currentTimeMillis(), null, conn, schema); + recordStepResultTxn(conn, ctx.schema(), output, System.currentTimeMillis(), null); return true; } else { return patchName.equals(checkpointName); @@ -335,34 +311,28 @@ static boolean patch( } } - static boolean deprecatePatch( - DataSource dataSource, String workflowId, int functionId, String patchName, String schema) + static boolean deprecatePatch(DbContext ctx, String workflowId, int functionId, String patchName) throws SQLException { Objects.requireNonNull(patchName, "patchName cannot be null"); - try (Connection conn = dataSource.getConnection()) { - var checkpointName = getCheckpointName(conn, workflowId, functionId, schema); + try (var conn = ctx.getConnection()) { + var checkpointName = getCheckpointName(conn, ctx.schema(), workflowId, functionId); return patchName.equals(checkpointName); } } static Duration durableSleepDuration( - DataSource dataSource, - String workflowUuid, - int functionId, - Duration duration, - String schema, - DBOSSerializer serializer) - throws SQLException { + DbContext ctx, String workflowUuid, int functionId, Duration duration) throws SQLException { - Objects.requireNonNull(schema); + DBOSSerializer serializer = ctx.serializer(); + Objects.requireNonNull(ctx.schema()); var startTime = System.currentTimeMillis(); String functionName = "DBOS.sleep"; StepResult recordedOutput; - try (Connection connection = dataSource.getConnection()) { + try (var conn = ctx.getConnection()) { recordedOutput = - checkStepExecutionTxn(workflowUuid, functionId, functionName, connection, schema); + checkStepExecutionTxn(conn, ctx.schema(), workflowUuid, functionId, functionName); } long endTime; @@ -398,7 +368,7 @@ static Duration durableSleepDuration( null, null, serializedValue.serialization()); - recordStepResultTxn(dataSource, output, startTime, (long) endTime, schema); + recordStepResultTxn(ctx, output, startTime, (long) endTime); } catch (DBOSWorkflowExecutionConflictException e) { logger.error("Error recording sleep", e); } diff --git a/transact/src/main/java/dev/dbos/transact/database/StreamsDAO.java b/transact/src/main/java/dev/dbos/transact/database/StreamsDAO.java index 59f47403..b0881766 100644 --- a/transact/src/main/java/dev/dbos/transact/database/StreamsDAO.java +++ b/transact/src/main/java/dev/dbos/transact/database/StreamsDAO.java @@ -10,29 +10,25 @@ import java.util.List; import java.util.Map; -import javax.sql.DataSource; - class StreamsDAO { private StreamsDAO() {} static void writeStreamFromStep( - DataSource dataSource, - String schema, + DbContext ctx, String workflowId, int functionId, String key, Object value, String serializationFormat) throws SQLException { - try (Connection conn = dataSource.getConnection()) { - insertStream(conn, schema, workflowId, functionId, key, value, serializationFormat); + try (var conn = ctx.getConnection()) { + insertStream(conn, ctx.schema(), workflowId, functionId, key, value, serializationFormat); } } static void writeStreamFromWorkflow( - DataSource dataSource, - String schema, + DbContext ctx, String workflowId, int functionId, String key, @@ -43,12 +39,13 @@ static void writeStreamFromWorkflow( STREAM_CLOSED_SENTINEL.equals(value) ? "DBOS.closeStream" : "DBOS.writeStream"; long startTime = System.currentTimeMillis(); - try (Connection conn = dataSource.getConnection()) { + try (var conn = ctx.getConnection()) { conn.setAutoCommit(false); try { StepResult recordedOutput = - StepsDAO.checkStepExecutionTxn(workflowId, functionId, functionName, conn, schema); + StepsDAO.checkStepExecutionTxn( + conn, ctx.schema(), workflowId, functionId, functionName); if (recordedOutput != null) { logger.debug("Replaying writeStream, id: {}, key: {}", functionId, key); @@ -58,10 +55,11 @@ static void writeStreamFromWorkflow( logger.debug("Running writeStream, id: {}, key: {}", functionId, key); } - insertStream(conn, schema, workflowId, functionId, key, value, serializationFormat); + insertStream(conn, ctx.schema(), workflowId, functionId, key, value, serializationFormat); var output = new StepResult(workflowId, functionId, functionName, null, null, null, null); - StepsDAO.recordStepResultTxn(output, startTime, System.currentTimeMillis(), conn, schema); + StepsDAO.recordStepResultTxn( + conn, ctx.schema(), output, startTime, System.currentTimeMillis()); conn.commit(); @@ -128,15 +126,13 @@ SELECT COALESCE(MAX("offset"), -1) + 1 } } - static void closeStream( - DataSource dataSource, String schema, String workflowId, int functionId, String key) + static void closeStream(DbContext ctx, String workflowId, int functionId, String key) throws SQLException { writeStreamFromWorkflow( - dataSource, schema, workflowId, functionId, key, STREAM_CLOSED_SENTINEL, "portable_json"); + ctx, workflowId, functionId, key, STREAM_CLOSED_SENTINEL, "portable_json"); } - static Object readStream( - DataSource dataSource, String schema, String workflowId, String key, int offset) + static Object readStream(DbContext ctx, String workflowId, String key, int offset) throws SQLException { String sql = """ @@ -144,9 +140,9 @@ static Object readStream( FROM "%s".streams WHERE workflow_uuid = ? AND key = ? AND "offset" = ? """ - .formatted(schema); + .formatted(ctx.schema()); - try (Connection conn = dataSource.getConnection(); + try (Connection conn = ctx.getConnection(); var stmt = conn.prepareStatement(sql)) { stmt.setString(1, workflowId); stmt.setString(2, key); @@ -167,8 +163,8 @@ static Object readStream( } } - static Map> getAllStreamEntries( - DataSource dataSource, String schema, String workflowId) throws SQLException { + static Map> getAllStreamEntries(DbContext ctx, String workflowId) + throws SQLException { String sql = """ SELECT key, value, serialization @@ -176,10 +172,10 @@ static Map> getAllStreamEntries( WHERE workflow_uuid = ? ORDER BY key, "offset" """ - .formatted(schema); + .formatted(ctx.schema()); var streams = new LinkedHashMap>(); - try (Connection conn = dataSource.getConnection(); + try (Connection conn = ctx.getConnection(); var stmt = conn.prepareStatement(sql)) { stmt.setString(1, workflowId); try (var rs = stmt.executeQuery()) { diff --git a/transact/src/main/java/dev/dbos/transact/database/SystemDatabase.java b/transact/src/main/java/dev/dbos/transact/database/SystemDatabase.java index e9e8df99..5c1b8805 100644 --- a/transact/src/main/java/dev/dbos/transact/database/SystemDatabase.java +++ b/transact/src/main/java/dev/dbos/transact/database/SystemDatabase.java @@ -24,6 +24,7 @@ import java.time.Duration; import java.time.Instant; import java.util.*; +import java.util.concurrent.atomic.AtomicBoolean; import javax.sql.DataSource; @@ -40,11 +41,10 @@ public static String sanitizeSchema(String schema) { return Objects.requireNonNullElse(schema, Constants.DB_SCHEMA).replace("\0", ""); } - private final DataSource dataSource; - private final String schema; + private final DbContext ctx; private final boolean created; - private final DBOSSerializer serializer; + private final AtomicBoolean closed = new AtomicBoolean(false); private final NotificationService notificationService; private Duration dbPollingInterval = Duration.ofSeconds(1); @@ -69,10 +69,8 @@ private SystemDatabase( throw new IllegalArgumentException("Schema name must not contain double quotes"); } - this.schema = schema; - this.dataSource = dataSource; + this.ctx = new DbContext(dataSource, schema, serializer, this.closed::get); this.created = created; - this.serializer = serializer; notificationService = new NotificationService(dataSource); } @@ -108,7 +106,7 @@ public static SystemDatabase create(DBOSConfig config) { } Optional getConfig() { - if (dataSource instanceof HikariDataSource hds) { + if (ctx.dataSource() instanceof HikariDataSource hds) { return Optional.of(hds); } return Optional.empty(); @@ -142,8 +140,9 @@ public static HikariDataSource createDataSource(String url, String user, String @Override public void close() { + closed.set(true); notificationService.stop(); - if (created && dataSource instanceof HikariDataSource hikariDataSource) { + if (created && ctx.dataSource() instanceof HikariDataSource hikariDataSource) { hikariDataSource.close(); } } @@ -198,6 +197,9 @@ private T dbRetry(SqlSupplier supplier) { final int MAX_RETRIES = 20; int attempt = 0; while (true) { + if (closed.get()) { + throw new IllegalStateException("SystemDatabase is closed"); + } try { return supplier.get(); } catch (SQLException e) { @@ -207,7 +209,7 @@ private T dbRetry(SqlSupplier supplier) { } if (e instanceof SQLRecoverableException || isConnectionFailure(e)) { logger.warn("Recoverable connection error. Resetting client pool.", e); - if (dataSource instanceof HikariDataSource hikariDataSource) { + if (ctx.dataSource() instanceof HikariDataSource hikariDataSource) { hikariDataSource.getHikariPoolMXBean().softEvictConnections(); } waitForRecovery(attempt, 2000); @@ -256,14 +258,7 @@ public WorkflowInitResult initWorkflowStatus( return dbRetry( () -> WorkflowDAO.initWorkflowStatus( - dataSource, - schema, - serializer, - initStatus, - maxRetries, - isRecoveryRequest, - isDequeuedRequest, - ownerXid)); + ctx, initStatus, maxRetries, isRecoveryRequest, isDequeuedRequest, ownerXid)); } /** @@ -273,7 +268,7 @@ public WorkflowInitResult initWorkflowStatus( * @param result output serialized as json */ public void recordWorkflowOutput(String workflowId, String result) { - dbRetry(() -> WorkflowDAO.recordWorkflowOutput(dataSource, schema, workflowId, result)); + dbRetry(() -> WorkflowDAO.recordWorkflowOutput(ctx, workflowId, result)); } /** @@ -283,77 +278,67 @@ public void recordWorkflowOutput(String workflowId, String result) { * @param error output serialized as json */ public void recordWorkflowError(String workflowId, String error) { - dbRetry(() -> WorkflowDAO.recordWorkflowError(dataSource, schema, workflowId, error)); + dbRetry(() -> WorkflowDAO.recordWorkflowError(ctx, workflowId, error)); } public WorkflowStatus getWorkflowStatus(String workflowId) { - return dbRetry(() -> WorkflowDAO.getWorkflowStatus(dataSource, schema, serializer, workflowId)); + return dbRetry(() -> WorkflowDAO.getWorkflowStatus(ctx, workflowId)); } public String getWorkflowSerialization(String workflowId) { - return dbRetry(() -> WorkflowDAO.getWorkflowSerialization(dataSource, schema, workflowId)); + return dbRetry(() -> WorkflowDAO.getWorkflowSerialization(ctx, workflowId)); } public List listWorkflows(ListWorkflowsInput input) { - return dbRetry(() -> WorkflowDAO.listWorkflows(dataSource, schema, serializer, input)); + return dbRetry(() -> WorkflowDAO.listWorkflows(ctx, input)); } public List getWorkflowAggregates(GetWorkflowAggregatesInput input) { - return dbRetry(() -> WorkflowDAO.getWorkflowAggregates(dataSource, schema, serializer, input)); + return dbRetry(() -> WorkflowDAO.getWorkflowAggregates(ctx, input)); } public List getPendingWorkflows(List executorIds, String appVersion) { - return dbRetry( - () -> - WorkflowDAO.getPendingWorkflows( - dataSource, schema, serializer, executorIds, appVersion)); + return dbRetry(() -> WorkflowDAO.getPendingWorkflows(ctx, executorIds, appVersion)); } public boolean clearQueueAssignment(String workflowId) { - return dbRetry(() -> QueuesDAO.clearQueueAssignment(dataSource, schema, workflowId)); + return dbRetry(() -> QueuesDAO.clearQueueAssignment(ctx, workflowId)); } public List getQueuePartitions(String queueName) { - return dbRetry(() -> QueuesDAO.getQueuePartitions(dataSource, schema, queueName)); + return dbRetry(() -> QueuesDAO.getQueuePartitions(ctx, queueName)); } public StepResult checkStepExecutionTxn(String workflowId, int functionId, String functionName) { return dbRetry( () -> { - try (Connection connection = dataSource.getConnection()) { + try (Connection connection = ctx.getConnection()) { return StepsDAO.checkStepExecutionTxn( - workflowId, functionId, functionName, connection, this.schema); + connection, ctx.schema(), workflowId, functionId, functionName); } }); } public void recordStepResultTxn(StepResult result, long startTime) { var et = System.currentTimeMillis(); - dbRetry(() -> StepsDAO.recordStepResultTxn(dataSource, result, startTime, et, this.schema)); + dbRetry(() -> StepsDAO.recordStepResultTxn(ctx, result, startTime, et)); } public List listWorkflowSteps( String workflowId, Boolean loadOutput, Integer limit, Integer offset) { - return dbRetry( - () -> - StepsDAO.listWorkflowSteps( - dataSource, workflowId, loadOutput, limit, offset, this.schema, this.serializer)); + return dbRetry(() -> StepsDAO.listWorkflowSteps(ctx, workflowId, loadOutput, limit, offset)); } public Result awaitWorkflowResult(String workflowId) { - return dbRetry( - () -> - WorkflowDAO.awaitWorkflowResult( - dataSource, schema, serializer, dbPollingInterval, workflowId)); + return dbRetry(() -> WorkflowDAO.awaitWorkflowResult(ctx, dbPollingInterval, workflowId)); } public List getAndStartQueuedWorkflows( Queue queue, String executorId, String appVersion, String partitionKey) { return dbRetry( () -> - QueuesDAO.getAndStartQueuedWorkflows( - dataSource, schema, queue, executorId, appVersion, partitionKey)); + QueuesDAO.getAndStartQueuedWorkflows(ctx, queue, executorId, appVersion, partitionKey)); } public void recordChildWorkflow( @@ -365,12 +350,11 @@ public void recordChildWorkflow( dbRetry( () -> WorkflowDAO.recordChildWorkflow( - dataSource, schema, parentId, childId, functionId, functionName, startTime)); + ctx, parentId, childId, functionId, functionName, startTime)); } public Optional checkChildWorkflow(String workflowUuid, int functionId) { - return dbRetry( - () -> WorkflowDAO.checkChildWorkflow(dataSource, schema, workflowUuid, functionId)); + return dbRetry(() -> WorkflowDAO.checkChildWorkflow(ctx, workflowUuid, functionId)); } public void send( @@ -384,16 +368,7 @@ public void send( dbRetry( () -> NotificationsDAO.send( - dataSource, - schema, - serializer, - workflowId, - stepId, - destinationId, - message, - topic, - messageId, - serialization)); + ctx, workflowId, stepId, destinationId, message, topic, messageId, serialization)); } public void sendDirect( @@ -401,14 +376,7 @@ public void sendDirect( dbRetry( () -> NotificationsDAO.sendDirect( - dataSource, - schema, - serializer, - destinationId, - message, - topic, - messageId, - serialization)); + ctx, destinationId, message, topic, messageId, serialization)); } public Object recv( @@ -416,9 +384,7 @@ public Object recv( return dbRetry( () -> NotificationsDAO.recv( - dataSource, - schema, - serializer, + ctx, notificationService, dbPollingInterval, workflowId, @@ -439,15 +405,7 @@ public void setEvent( dbRetry( () -> NotificationsDAO.setEvent( - dataSource, - schema, - serializer, - workflowId, - functionId, - key, - message, - asStep, - serialization)); + ctx, workflowId, functionId, key, message, asStep, serialization)); } public Object getEvent( @@ -456,77 +414,66 @@ public Object getEvent( return dbRetry( () -> NotificationsDAO.getEvent( - dataSource, - schema, - serializer, - notificationService, - dbPollingInterval, - targetId, - key, - timeout, - callerCtx)); + ctx, notificationService, dbPollingInterval, targetId, key, timeout, callerCtx)); } public void sleep(String workflowId, int functionId, Duration duration) { - dbRetry(() -> StepsDAO.sleep(dataSource, workflowId, functionId, duration, schema, serializer)); + dbRetry(() -> StepsDAO.sleep(ctx, workflowId, functionId, duration)); } public void cancelWorkflows(List workflowIds) { - dbRetry(() -> WorkflowDAO.cancelWorkflows(dataSource, schema, workflowIds)); + dbRetry(() -> WorkflowDAO.cancelWorkflows(ctx, workflowIds)); } public void resumeWorkflows(List workflowIds, String queueName) { - dbRetry(() -> WorkflowDAO.resumeWorkflows(dataSource, schema, workflowIds, queueName)); + dbRetry(() -> WorkflowDAO.resumeWorkflows(ctx, workflowIds, queueName)); } public void deleteWorkflows(List workflowIds, boolean deleteChildren) { - dbRetry(() -> WorkflowDAO.deleteWorkflows(dataSource, schema, workflowIds, deleteChildren)); + dbRetry(() -> WorkflowDAO.deleteWorkflows(ctx, workflowIds, deleteChildren)); } public String forkWorkflow(String originalWorkflowId, int startStep, ForkOptions options) { - return dbRetry( - () -> - WorkflowDAO.forkWorkflow( - dataSource, schema, serializer, originalWorkflowId, startStep, options)); + return dbRetry(() -> WorkflowDAO.forkWorkflow(ctx, originalWorkflowId, startStep, options)); } public void createApplicationVersion(String versionName) { - dbRetry(() -> ApplicationVersionDAO.createApplicationVersion(dataSource, schema, versionName)); + dbRetry(() -> ApplicationVersionDAO.createApplicationVersion(ctx, versionName)); } public void updateApplicationVersionTimestamp(String versionName, Instant newTimestamp) { dbRetry( () -> ApplicationVersionDAO.updateApplicationVersionTimestamp( - dataSource, schema, versionName, newTimestamp)); + ctx, versionName, newTimestamp)); } public List listApplicationVersions() { - return dbRetry(() -> ApplicationVersionDAO.listApplicationVersions(dataSource, schema)); + return dbRetry(() -> ApplicationVersionDAO.listApplicationVersions(ctx)); } public VersionInfo getLatestApplicationVersion() { - return dbRetry(() -> ApplicationVersionDAO.getLatestApplicationVersion(dataSource, schema)); + return dbRetry(() -> ApplicationVersionDAO.getLatestApplicationVersion(ctx)); } public void garbageCollect(Instant cutoff, Long rowsThreshold) { - dbRetry(() -> WorkflowDAO.garbageCollect(dataSource, schema, cutoff, rowsThreshold)); + dbRetry(() -> WorkflowDAO.garbageCollect(ctx, cutoff, rowsThreshold)); } public void setWorkflowDelay(String workflowId, WorkflowDelay delay) { - dbRetry(() -> WorkflowDAO.setWorkflowDelay(dataSource, schema, workflowId, delay)); + dbRetry(() -> WorkflowDAO.setWorkflowDelay(ctx, workflowId, delay)); } public void transitionDelayedWorkflows() { - dbRetry(() -> WorkflowDAO.transitionDelayedWorkflows(dataSource, schema)); + dbRetry(() -> WorkflowDAO.transitionDelayedWorkflows(ctx)); } public void createSchedule(WorkflowSchedule schedule) { - dbRetry(() -> SchedulesDAO.createSchedule(dataSource, schema, serializer, schedule)); + dbRetry(() -> SchedulesDAO.createSchedule(ctx, schedule)); } public Optional getSchedule(String name) { - return dbRetry(() -> SchedulesDAO.getSchedule(dataSource, schema, serializer, name)); + return dbRetry(() -> SchedulesDAO.getSchedule(ctx, name)); } public List listSchedules( @@ -534,74 +481,67 @@ public List listSchedules( List workflowNames, List scheduleNamePrefixes) { return dbRetry( - () -> - SchedulesDAO.listSchedules( - dataSource, schema, serializer, statuses, workflowNames, scheduleNamePrefixes)); + () -> SchedulesDAO.listSchedules(ctx, statuses, workflowNames, scheduleNamePrefixes)); } public void pauseSchedule(String name) { - dbRetry(() -> SchedulesDAO.pauseSchedule(dataSource, schema, name)); + dbRetry(() -> SchedulesDAO.pauseSchedule(ctx, name)); } public void resumeSchedule(String name) { - dbRetry(() -> SchedulesDAO.resumeSchedule(dataSource, schema, name)); + dbRetry(() -> SchedulesDAO.resumeSchedule(ctx, name)); } public void updateScheduleLastFiredAt(String name, Instant lastFiredAt) { - dbRetry(() -> SchedulesDAO.updateScheduleLastFiredAt(dataSource, schema, name, lastFiredAt)); + dbRetry(() -> SchedulesDAO.updateScheduleLastFiredAt(ctx, name, lastFiredAt)); } public void deleteSchedule(String name) { - dbRetry(() -> SchedulesDAO.deleteSchedule(dataSource, schema, name)); + dbRetry(() -> SchedulesDAO.deleteSchedule(ctx, name)); } public void applySchedules(List schedules) { - dbRetry(() -> SchedulesDAO.applySchedules(dataSource, schema, serializer, schedules)); + dbRetry(() -> SchedulesDAO.applySchedules(ctx, schedules)); } public Optional getExternalState(String service, String workflowName, String key) { - return dbRetry( - () -> ExternalStateDAO.getExternalState(dataSource, schema, service, workflowName, key)); + return dbRetry(() -> ExternalStateDAO.getExternalState(ctx, service, workflowName, key)); } public ExternalState upsertExternalState(ExternalState state) { - return dbRetry(() -> ExternalStateDAO.upsertExternalState(dataSource, schema, state)); + return dbRetry(() -> ExternalStateDAO.upsertExternalState(ctx, state)); } public List getMetrics(Instant startTime, Instant endTime) { - return dbRetry(() -> WorkflowDAO.getMetrics(dataSource, schema, startTime, endTime)); + return dbRetry(() -> WorkflowDAO.getMetrics(ctx, startTime, endTime)); } public boolean patch(String workflowId, int functionId, String patchName) { - return dbRetry(() -> StepsDAO.patch(dataSource, workflowId, functionId, patchName, schema)); + return dbRetry(() -> StepsDAO.patch(ctx, workflowId, functionId, patchName)); } public boolean deprecatePatch(String workflowId, int functionId, String patchName) { - return dbRetry( - () -> StepsDAO.deprecatePatch(dataSource, workflowId, functionId, patchName, schema)); + return dbRetry(() -> StepsDAO.deprecatePatch(ctx, workflowId, functionId, patchName)); } public Set getWorkflowChildren(String workflowId) { - return dbRetry(() -> WorkflowDAO.getWorkflowChildren(dataSource, schema, workflowId)); + return dbRetry(() -> WorkflowDAO.getWorkflowChildren(ctx, workflowId)); } public Map getAllEvents(String workflowId) { - return dbRetry(() -> WorkflowDAO.getAllEvents(dataSource, schema, serializer, workflowId)); + return dbRetry(() -> WorkflowDAO.getAllEvents(ctx, workflowId)); } public List getAllNotifications(String workflowId) { - return dbRetry( - () -> NotificationsDAO.getAllNotifications(dataSource, schema, serializer, workflowId)); + return dbRetry(() -> NotificationsDAO.getAllNotifications(ctx, workflowId)); } public List exportWorkflow(String workflowId, boolean exportChildren) { - return dbRetry( - () -> - WorkflowDAO.exportWorkflow(dataSource, schema, serializer, workflowId, exportChildren)); + return dbRetry(() -> WorkflowDAO.exportWorkflow(ctx, workflowId, exportChildren)); } public void importWorkflow(List workflows) { - dbRetry(() -> WorkflowDAO.importWorkflow(dataSource, schema, serializer, workflows)); + dbRetry(() -> WorkflowDAO.importWorkflow(ctx, workflows)); } public void writeStreamFromStep( @@ -609,7 +549,7 @@ public void writeStreamFromStep( dbRetry( () -> StreamsDAO.writeStreamFromStep( - dataSource, schema, workflowId, functionId, key, value, serializationFormat)); + ctx, workflowId, functionId, key, value, serializationFormat)); } public void writeStreamFromWorkflow( @@ -617,18 +557,18 @@ public void writeStreamFromWorkflow( dbRetry( () -> StreamsDAO.writeStreamFromWorkflow( - dataSource, schema, workflowId, functionId, key, value, serializationFormat)); + ctx, workflowId, functionId, key, value, serializationFormat)); } public void closeStream(String workflowId, int functionId, String key) { - dbRetry(() -> StreamsDAO.closeStream(dataSource, schema, workflowId, functionId, key)); + dbRetry(() -> StreamsDAO.closeStream(ctx, workflowId, functionId, key)); } public Object readStream(String workflowId, String key, int offset) { - return dbRetry(() -> StreamsDAO.readStream(dataSource, schema, workflowId, key, offset)); + return dbRetry(() -> StreamsDAO.readStream(ctx, workflowId, key, offset)); } public Map> getAllStreamEntries(String workflowId) { - return dbRetry(() -> StreamsDAO.getAllStreamEntries(dataSource, schema, workflowId)); + return dbRetry(() -> StreamsDAO.getAllStreamEntries(ctx, workflowId)); } } diff --git a/transact/src/main/java/dev/dbos/transact/database/WorkflowDAO.java b/transact/src/main/java/dev/dbos/transact/database/WorkflowDAO.java index 4119571a..d9fe41ad 100644 --- a/transact/src/main/java/dev/dbos/transact/database/WorkflowDAO.java +++ b/transact/src/main/java/dev/dbos/transact/database/WorkflowDAO.java @@ -46,8 +46,6 @@ import java.util.UUID; import java.util.stream.Stream; -import javax.sql.DataSource; - import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -72,9 +70,7 @@ class WorkflowDAO { private WorkflowDAO() {} static WorkflowInitResult initWorkflowStatus( - DataSource dataSource, - String schema, - DBOSSerializer serializer, + DbContext ctx, WorkflowStatusInternal initStatus, Integer maxRetries, boolean isRecoveryRequest, @@ -84,17 +80,17 @@ static WorkflowInitResult initWorkflowStatus( logger.debug("initWorkflowStatus workflowId {}", initStatus.workflowId()); - try (Connection connection = dataSource.getConnection()) { + try (var conn = ctx.getConnection()) { boolean shouldCommit = false; try { - connection.setAutoCommit(false); - connection.setTransactionIsolation(Connection.TRANSACTION_READ_COMMITTED); + conn.setAutoCommit(false); + conn.setTransactionIsolation(Connection.TRANSACTION_READ_COMMITTED); InsertWorkflowResult resRow = insertWorkflowStatus( - connection, schema, initStatus, ownerXid, isRecoveryRequest || isDequeuedRequest); + conn, ctx.schema(), initStatus, ownerXid, isRecoveryRequest || isDequeuedRequest); if (!Objects.equals(resRow.workflowName(), initStatus.workflowName())) { String msg = @@ -142,9 +138,9 @@ static WorkflowInitResult initWorkflowStatus( SET status = ?, deduplication_id = NULL, started_at_epoch_ms = NULL, queue_name = NULL WHERE workflow_uuid = ? AND status = ? """ - .formatted(schema); + .formatted(ctx.schema()); - try (PreparedStatement stmt = connection.prepareStatement(sql)) { + try (PreparedStatement stmt = conn.prepareStatement(sql)) { stmt.setString(1, WorkflowState.MAX_RECOVERY_ATTEMPTS_EXCEEDED.name()); stmt.setString(2, initStatus.workflowId()); stmt.setString(3, WorkflowState.PENDING.name()); @@ -160,9 +156,9 @@ static WorkflowInitResult initWorkflowStatus( } finally { if (shouldCommit) { - connection.commit(); + conn.commit(); } else { - connection.rollback(); + conn.rollback(); } DebugTriggers.debugTriggerPoint(DebugTriggers.DEBUG_TRIGGER_INITWF_COMMIT); } @@ -188,7 +184,7 @@ static record InsertWorkflowResult( * @throws SQLException */ static InsertWorkflowResult insertWorkflowStatus( - Connection connection, + Connection conn, String schema, WorkflowStatusInternal status, String ownerXid, @@ -239,7 +235,7 @@ ON CONFLICT (workflow_uuid) status.authenticatedRoles() != null ? JsonUtility.toJson(status.authenticatedRoles()) : null; - try (PreparedStatement stmt = connection.prepareStatement(insertSQL)) { + try (var stmt = conn.prepareStatement(insertSQL)) { var now = System.currentTimeMillis(); stmt.setString(1, status.workflowId()); @@ -310,7 +306,7 @@ ON CONFLICT (workflow_uuid) } static void updateWorkflowOutcome( - Connection connection, + Connection conn, String schema, String workflowId, WorkflowState status, @@ -329,7 +325,7 @@ static void updateWorkflowOutcome( """ .formatted(schema); - try (PreparedStatement stmt = connection.prepareStatement(sql)) { + try (var stmt = conn.prepareStatement(sql)) { stmt.setString(1, status.name()); stmt.setString(2, output); stmt.setString(3, error); @@ -350,11 +346,11 @@ static void updateWorkflowOutcome( * @param workflowId id of the workflow * @param result output serialized as json */ - static void recordWorkflowOutput( - DataSource dataSource, String schema, String workflowId, String result) throws SQLException { + static void recordWorkflowOutput(DbContext ctx, String workflowId, String result) + throws SQLException { - try (Connection connection = dataSource.getConnection()) { - updateWorkflowOutcome(connection, schema, workflowId, WorkflowState.SUCCESS, result, null); + try (var conn = ctx.getConnection()) { + updateWorkflowOutcome(conn, ctx.schema(), workflowId, WorkflowState.SUCCESS, result, null); } } @@ -364,20 +360,19 @@ static void recordWorkflowOutput( * @param workflowId id of the workflow * @param error output serialized as json */ - static void recordWorkflowError( - DataSource dataSource, String schema, String workflowId, String error) throws SQLException { + static void recordWorkflowError(DbContext ctx, String workflowId, String error) + throws SQLException { - try (Connection connection = dataSource.getConnection()) { - updateWorkflowOutcome(connection, schema, workflowId, WorkflowState.ERROR, null, error); + try (var conn = ctx.getConnection()) { + updateWorkflowOutcome(conn, ctx.schema(), workflowId, WorkflowState.ERROR, null, error); } } - static String getWorkflowSerialization(DataSource dataSource, String schema, String workflowId) - throws SQLException { + static String getWorkflowSerialization(DbContext ctx, String workflowId) throws SQLException { var sql = "SELECT serialization FROM \"%s\".workflow_status WHERE workflow_uuid = ?" - .formatted(schema); - try (var conn = dataSource.getConnection(); + .formatted(ctx.schema()); + try (var conn = ctx.getConnection(); var stmt = conn.prepareStatement(sql)) { stmt.setString(1, workflowId); try (var rs = stmt.executeQuery()) { @@ -389,12 +384,10 @@ static String getWorkflowSerialization(DataSource dataSource, String schema, Str return null; } - static WorkflowStatus getWorkflowStatus( - DataSource dataSource, String schema, DBOSSerializer serializer, String workflowId) - throws SQLException { + static WorkflowStatus getWorkflowStatus(DbContext ctx, String workflowId) throws SQLException { - try (var conn = dataSource.getConnection()) { - return getWorkflowStatus(conn, schema, serializer, workflowId); + try (var conn = ctx.getConnection()) { + return getWorkflowStatus(conn, ctx.schema(), ctx.serializer(), workflowId); } } @@ -421,8 +414,7 @@ static WorkflowStatus getWorkflowStatus( return null; } - static void setWorkflowDelay( - DataSource dataSource, String schema, String workflowId, WorkflowDelay delay) + static void setWorkflowDelay(DbContext ctx, String workflowId, WorkflowDelay delay) throws SQLException { Objects.requireNonNull(workflowId, "workflowId must not be null"); Objects.requireNonNull(delay, "delay must not be null"); @@ -446,8 +438,8 @@ static void setWorkflowDelay( WHERE workflow_uuid = ? AND status = ? """ - .formatted(schema); - try (var conn = dataSource.getConnection(); + .formatted(ctx.schema()); + try (var conn = ctx.getConnection(); var stmt = conn.prepareStatement(sql)) { stmt.setLong(1, resolved.toEpochMilli()); stmt.setString(2, workflowId); @@ -457,7 +449,7 @@ static void setWorkflowDelay( } } - static void transitionDelayedWorkflows(DataSource dataSource, String schema) throws SQLException { + static void transitionDelayedWorkflows(DbContext ctx) throws SQLException { var sql = """ UPDATE "%s".workflow_status @@ -465,9 +457,9 @@ static void transitionDelayedWorkflows(DataSource dataSource, String schema) thr WHERE status = ? AND delay_until_epoch_ms <= ? """ - .formatted(schema); + .formatted(ctx.schema()); - try (var conn = dataSource.getConnection(); + try (var conn = ctx.getConnection(); var stmt = conn.prepareStatement(sql)) { stmt.setString(1, WorkflowState.ENQUEUED.name()); stmt.setString(2, WorkflowState.DELAYED.name()); @@ -477,10 +469,10 @@ static void transitionDelayedWorkflows(DataSource dataSource, String schema) thr } } - static List listWorkflows( - DataSource dataSource, String schema, DBOSSerializer serializer, ListWorkflowsInput input) + static List listWorkflows(DbContext ctx, ListWorkflowsInput input) throws SQLException { + DBOSSerializer serializer = ctx.serializer(); if (input == null) { input = new ListWorkflowsInput(); } @@ -504,7 +496,7 @@ static List listWorkflows( sqlBuilder.append(", serialization"); } - sqlBuilder.append(" FROM \"%s\".workflow_status ".formatted(schema)); + sqlBuilder.append(" FROM \"%s\".workflow_status ".formatted(ctx.schema())); // --- WHERE Clauses --- StringJoiner whereConditions = new StringJoiner(" AND "); @@ -616,7 +608,7 @@ static List listWorkflows( parameters.add(input.offset()); } - try (Connection connection = dataSource.getConnection(); + try (Connection connection = ctx.getConnection(); PreparedStatement pstmt = connection.prepareStatement(sqlBuilder.toString())) { List arrays = new ArrayList<>(); try { @@ -654,11 +646,8 @@ static List listWorkflows( } static List getWorkflowAggregates( - DataSource dataSource, - String schema, - DBOSSerializer serializer, - GetWorkflowAggregatesInput input) - throws SQLException { + DbContext ctx, GetWorkflowAggregatesInput input) throws SQLException { + if (input == null) { input = new GetWorkflowAggregatesInput(); } @@ -684,7 +673,7 @@ record GroupDim(String name, String column) {} StringJoiner selectCols = new StringJoiner(", "); for (var dim : dims) selectCols.add(dim.column()); selectCols.add("COUNT(*) AS count"); - sqlBuilder.append(selectCols).append(" FROM \"%s\".workflow_status".formatted(schema)); + sqlBuilder.append(selectCols).append(" FROM \"%s\".workflow_status".formatted(ctx.schema())); // --- WHERE --- StringJoiner whereConditions = new StringJoiner(" AND "); @@ -738,7 +727,7 @@ record GroupDim(String name, String column) {} sqlBuilder.append(" ORDER BY ").append(groupByCols); List results = new ArrayList<>(); - try (Connection connection = dataSource.getConnection(); + try (Connection connection = ctx.getConnection(); PreparedStatement pstmt = connection.prepareStatement(sqlBuilder.toString())) { List arrays = new ArrayList<>(); try { @@ -827,39 +816,31 @@ private static WorkflowStatus resultsToWorkflowStatus( } static List getPendingWorkflows( - DataSource dataSource, - String schema, - DBOSSerializer serializer, - List executorIds, - String appVersion) - throws SQLException { + DbContext ctx, List executorIds, String appVersion) throws SQLException { var input = new ListWorkflowsInput() .withStatus(WorkflowState.PENDING) .withExecutorIds(executorIds) .withApplicationVersion(appVersion); - return listWorkflows(dataSource, schema, serializer, input); + return listWorkflows(ctx, input); } @SuppressWarnings("unchecked") static Result awaitWorkflowResult( - DataSource dataSource, - String schema, - DBOSSerializer serializer, - Duration dbPollingInterval, - String workflowId) - throws SQLException { + DbContext ctx, Duration dbPollingInterval, String workflowId) throws SQLException { + DBOSSerializer serializer = ctx.serializer(); final String sql = """ SELECT status, output, error, serialization FROM "%s".workflow_status WHERE workflow_uuid = ? """ - .formatted(schema); + .formatted(ctx.schema()); while (true) { - try (Connection connection = dataSource.getConnection(); + if (ctx.isClosed()) throw new IllegalStateException("SystemDatabase is closed"); + try (Connection connection = ctx.getConnection(); PreparedStatement stmt = connection.prepareStatement(sql)) { stmt.setString(1, workflowId); @@ -902,8 +883,7 @@ static Result awaitWorkflowResult( } static void recordChildWorkflow( - DataSource dataSource, - String schema, + DbContext ctx, String parentId, String childId, // workflowId of the child int functionId, // func id in the parent @@ -914,22 +894,21 @@ static void recordChildWorkflow( var result = new StepResult(parentId, functionId, functionName, null, null, null, null) .withChildWorkflowId(childId); - try (Connection connection = dataSource.getConnection()) { - StepsDAO.recordStepResultTxn(result, null, null, connection, schema); + try (var conn = ctx.getConnection()) { + StepsDAO.recordStepResultTxn(conn, ctx.schema(), result, null, null); } } - static Optional checkChildWorkflow( - DataSource dataSource, String schema, String workflowUuid, int functionId) + static Optional checkChildWorkflow(DbContext ctx, String workflowUuid, int functionId) throws SQLException { final String sql = """ SELECT child_workflow_id FROM "%s".operation_outputs WHERE workflow_uuid = ? AND function_id = ? """ - .formatted(schema); + .formatted(ctx.schema()); - try (Connection connection = dataSource.getConnection(); + try (Connection connection = ctx.getConnection(); PreparedStatement stmt = connection.prepareStatement(sql)) { stmt.setString(1, workflowUuid); @@ -952,8 +931,7 @@ private static List filterNullsAndBlanks(List workflowIds) { return workflowIds.stream().filter(id -> id != null && !id.isBlank()).toList(); } - static void cancelWorkflows(DataSource dataSource, String schema, List workflowIds) - throws SQLException { + static void cancelWorkflows(DbContext ctx, List workflowIds) throws SQLException { List filtered = filterNullsAndBlanks(workflowIds); if (filtered.isEmpty()) { return; @@ -969,9 +947,9 @@ static void cancelWorkflows(DataSource dataSource, String schema, List w WHERE workflow_uuid = ANY(?) AND status NOT IN (?, ?) """ - .formatted(schema); + .formatted(ctx.schema()); - try (Connection conn = dataSource.getConnection(); + try (Connection conn = ctx.getConnection(); PreparedStatement stmt = conn.prepareStatement(sql)) { Array array = conn.createArrayOf("text", filtered.toArray(String[]::new)); try { @@ -986,8 +964,7 @@ AND status NOT IN (?, ?) } } - static void resumeWorkflows( - DataSource dataSource, String schema, List workflowIds, String queueName) + static void resumeWorkflows(DbContext ctx, List workflowIds, String queueName) throws SQLException { List filtered = filterNullsAndBlanks(workflowIds); if (filtered.isEmpty()) { @@ -1007,9 +984,9 @@ static void resumeWorkflows( WHERE workflow_uuid = ANY(?) AND status NOT IN (?, ?) """ - .formatted(schema); + .formatted(ctx.schema()); - try (Connection conn = dataSource.getConnection(); + try (Connection conn = ctx.getConnection(); PreparedStatement stmt = conn.prepareStatement(sql)) { Array array = conn.createArrayOf("text", filtered.toArray(String[]::new)); try { @@ -1025,8 +1002,7 @@ AND status NOT IN (?, ?) } } - static void deleteWorkflows( - DataSource dataSource, String schema, List workflowIds, boolean deleteChildren) + static void deleteWorkflows(DbContext ctx, List workflowIds, boolean deleteChildren) throws SQLException { List filtered = filterNullsAndBlanks(workflowIds); if (filtered.isEmpty()) { @@ -1036,7 +1012,7 @@ static void deleteWorkflows( var wfIdSet = new HashSet(filtered); if (deleteChildren) { for (var wfid : filtered) { - var children = getWorkflowChildren(dataSource, schema, wfid); + var children = getWorkflowChildren(ctx, wfid); wfIdSet.addAll(children); } } @@ -1046,9 +1022,9 @@ static void deleteWorkflows( DELETE FROM "%s".workflow_status WHERE workflow_uuid = ANY(?); """ - .formatted(schema); + .formatted(ctx.schema()); - try (var conn = dataSource.getConnection(); + try (var conn = ctx.getConnection(); var stmt = conn.prepareStatement(sql)) { var array = conn.createArrayOf("text", wfIdSet.toArray(String[]::new)); try { @@ -1060,8 +1036,7 @@ static void deleteWorkflows( } } - static Set getWorkflowChildren(DataSource dataSource, String schema, String workflowId) - throws SQLException { + static Set getWorkflowChildren(DbContext ctx, String workflowId) throws SQLException { var children = new HashSet(); var toProcess = new ArrayDeque(); toProcess.add(workflowId); @@ -1072,9 +1047,9 @@ static Set getWorkflowChildren(DataSource dataSource, String schema, Str FROM "%s".operation_outputs WHERE workflow_uuid = ? AND child_workflow_id IS NOT NULL """ - .formatted(schema); + .formatted(ctx.schema()); - try (var conn = dataSource.getConnection(); + try (var conn = ctx.getConnection(); var stmt = conn.prepareStatement(sql)) { while (!toProcess.isEmpty()) { var wfid = toProcess.poll(); @@ -1095,17 +1070,12 @@ static Set getWorkflowChildren(DataSource dataSource, String schema, Str } static String forkWorkflow( - DataSource dataSource, - String schema, - DBOSSerializer serializer, - String originalWorkflowId, - int startStep, - ForkOptions options) + DbContext ctx, String originalWorkflowId, int startStep, ForkOptions options) throws SQLException { options = Objects.requireNonNullElseGet(options, ForkOptions::new); - var status = getWorkflowStatus(dataSource, schema, serializer, originalWorkflowId); + var status = getWorkflowStatus(ctx, originalWorkflowId); if (status == null) { throw new DBOSNonExistentWorkflowException(originalWorkflowId); } @@ -1124,52 +1094,52 @@ static String forkWorkflow( timeoutMS = explicit.value().toMillis(); } - try (Connection connection = dataSource.getConnection()) { - connection.setAutoCommit(false); + try (var conn = ctx.getConnection()) { + conn.setAutoCommit(false); try { // Create entry for forked workflow insertForkedWorkflowStatus( - connection, + conn, + ctx.schema(), + ctx.serializer(), originalWorkflowId, forkedWorkflowId, status, options.applicationVersion(), timeoutMS, options.queueName(), - options.queuePartitionKey(), - schema, - serializer); + options.queuePartitionKey()); // Copy operation outputs if starting from step > 0 if (startStep > 0) { - copyOperationOutputs(connection, originalWorkflowId, forkedWorkflowId, startStep, schema); + copyOperationOutputs(conn, ctx.schema(), originalWorkflowId, forkedWorkflowId, startStep); } // Mark the original workflow as having been forked - markWasForkedFrom(connection, originalWorkflowId, schema); + markWasForkedFrom(conn, ctx.schema(), originalWorkflowId); - connection.commit(); + conn.commit(); return forkedWorkflowId; } catch (SQLException e) { - connection.rollback(); + conn.rollback(); throw e; } } } private static void insertForkedWorkflowStatus( - Connection connection, + Connection conn, + String schema, + DBOSSerializer serializer, String originalWorkflowId, String forkedWorkflowId, WorkflowStatus originalStatus, String applicationVersion, Long timeoutMS, String queueName, - String queuePartitionKey, - String schema, - DBOSSerializer serializer) + String queuePartitionKey) throws SQLException { Objects.requireNonNull(schema); @@ -1183,7 +1153,7 @@ private static void insertForkedWorkflowStatus( """ .formatted(schema); - try (PreparedStatement stmt = connection.prepareStatement(sql)) { + try (var stmt = conn.prepareStatement(sql)) { stmt.setString(1, forkedWorkflowId); stmt.setString(2, WorkflowState.ENQUEUED.name()); stmt.setString(3, originalStatus.workflowName()); @@ -1213,7 +1183,7 @@ private static void insertForkedWorkflowStatus( } } - private static void markWasForkedFrom(Connection connection, String workflowId, String schema) + private static void markWasForkedFrom(Connection conn, String schema, String workflowId) throws SQLException { String sql = """ @@ -1222,18 +1192,18 @@ private static void markWasForkedFrom(Connection connection, String workflowId, WHERE workflow_uuid = ? """ .formatted(schema); - try (PreparedStatement stmt = connection.prepareStatement(sql)) { + try (var stmt = conn.prepareStatement(sql)) { stmt.setString(1, workflowId); stmt.executeUpdate(); } } private static void copyOperationOutputs( - Connection connection, + Connection conn, + String schema, String originalWorkflowId, String forkedWorkflowId, - int startStep, - String schema) + int startStep) throws SQLException { String stepOutputsSql = @@ -1245,7 +1215,7 @@ private static void copyOperationOutputs( WHERE workflow_uuid = ? AND function_id < ? """ .formatted(schema); - try (PreparedStatement stmt = connection.prepareStatement(stepOutputsSql)) { + try (var stmt = conn.prepareStatement(stepOutputsSql)) { stmt.setString(1, forkedWorkflowId); stmt.setString(2, originalWorkflowId); stmt.setInt(3, startStep); @@ -1263,7 +1233,7 @@ private static void copyOperationOutputs( WHERE workflow_uuid = ? AND function_id < ? """ .formatted(schema); - try (PreparedStatement stmt = connection.prepareStatement(eventHistorySql)) { + try (var stmt = conn.prepareStatement(eventHistorySql)) { stmt.setString(1, forkedWorkflowId); stmt.setString(2, originalWorkflowId); stmt.setInt(3, startStep); @@ -1289,7 +1259,7 @@ SELECT MAX(weh2.function_id) """ .formatted(schema); - try (PreparedStatement stmt = connection.prepareStatement(eventSql)) { + try (var stmt = conn.prepareStatement(eventSql)) { stmt.setString(1, forkedWorkflowId); stmt.setString(2, originalWorkflowId); stmt.setString(3, originalWorkflowId); @@ -1308,7 +1278,7 @@ SELECT MAX(weh2.function_id) WHERE workflow_uuid = ? AND function_id < ? """ .formatted(schema); - try (PreparedStatement stmt = connection.prepareStatement(streamsSql)) { + try (var stmt = conn.prepareStatement(streamsSql)) { stmt.setString(1, forkedWorkflowId); stmt.setString(2, originalWorkflowId); stmt.setInt(3, startStep); @@ -1318,14 +1288,14 @@ SELECT MAX(weh2.function_id) } } - private static Instant getRowsCutoff(Connection connection, long rowsThreshold, String schema) + private static Instant getRowsCutoff(Connection conn, String schema, long rowsThreshold) throws SQLException { String sql = """ SELECT created_at FROM "%s".workflow_status ORDER BY created_at DESC OFFSET ? LIMIT 1 """ .formatted(schema); - try (PreparedStatement stmt = connection.prepareStatement(sql)) { + try (var stmt = conn.prepareStatement(sql)) { stmt.setLong(1, rowsThreshold - 1); try (ResultSet rs = stmt.executeQuery()) { if (rs.next()) { @@ -1337,13 +1307,12 @@ private static Instant getRowsCutoff(Connection connection, long rowsThreshold, return null; } - static void garbageCollect( - DataSource dataSource, String schema, Instant cutoff, Long rowsThreshold) + static void garbageCollect(DbContext ctx, Instant cutoff, Long rowsThreshold) throws SQLException { - try (Connection connection = dataSource.getConnection()) { + try (var conn = ctx.getConnection()) { if (rowsThreshold != null) { - var rowsCutoff = getRowsCutoff(connection, rowsThreshold, schema); + var rowsCutoff = getRowsCutoff(conn, ctx.schema(), rowsThreshold); if (rowsCutoff != null) { if (cutoff == null || rowsCutoff.isAfter(cutoff)) { cutoff = rowsCutoff; @@ -1356,8 +1325,8 @@ static void garbageCollect( """ DELETE FROM "%s".workflow_status WHERE created_at < ? AND status NOT IN (?, ?, ?) """ - .formatted(schema); - try (PreparedStatement stmt = connection.prepareStatement(sql)) { + .formatted(ctx.schema()); + try (var stmt = conn.prepareStatement(sql)) { stmt.setLong(1, cutoff.toEpochMilli()); stmt.setString(2, WorkflowState.PENDING.name()); stmt.setString(3, WorkflowState.ENQUEUED.name()); @@ -1369,8 +1338,7 @@ static void garbageCollect( } } - static List getMetrics( - DataSource dataSource, String schema, Instant startTime, Instant endTime) + static List getMetrics(DbContext ctx, Instant startTime, Instant endTime) throws SQLException { final var start = Objects.requireNonNull(startTime).toEpochMilli(); final var end = Objects.requireNonNull(endTime).toEpochMilli(); @@ -1383,7 +1351,7 @@ SELECT name, COUNT(workflow_uuid) as count WHERE created_at >= ? AND created_at < ? GROUP BY name """ - .formatted(schema); + .formatted(ctx.schema()); final var stepSQL = """ SELECT function_name, COUNT(*) as count @@ -1391,9 +1359,9 @@ SELECT function_name, COUNT(*) as count WHERE completed_at_epoch_ms >= ? AND completed_at_epoch_ms < ? GROUP BY function_name """ - .formatted(schema); + .formatted(ctx.schema()); - try (var conn = dataSource.getConnection(); + try (var conn = ctx.getConnection(); var ps1 = conn.prepareStatement(wfSQL); var ps2 = conn.prepareStatement(stepSQL)) { @@ -1502,40 +1470,34 @@ static List listWorkflowStreams(Connection conn, String schema, } static List exportWorkflow( - DataSource dataSource, - String schema, - DBOSSerializer serializer, - String workflowId, - boolean exportChildren) - throws SQLException { + DbContext ctx, String workflowId, boolean exportChildren) throws SQLException { + var workflowIds = exportChildren ? Stream.concat( - getWorkflowChildren(dataSource, schema, workflowId).stream(), - List.of(workflowId).stream()) + getWorkflowChildren(ctx, workflowId).stream(), List.of(workflowId).stream()) .toList() : List.of(workflowId); var workflows = new ArrayList(); for (var wfid : workflowIds) { - try (var conn = dataSource.getConnection()) { - var status = getWorkflowStatus(conn, schema, serializer, wfid); - var steps = StepsDAO.listWorkflowSteps(conn, wfid, true, null, null, schema, serializer); - var events = listWorkflowEvents(conn, schema, wfid); - var eventHistory = listWorkflowEventHistory(conn, schema, wfid); - var streams = listWorkflowStreams(conn, schema, wfid); + try (var conn = ctx.getConnection()) { + var status = getWorkflowStatus(conn, ctx.schema(), ctx.serializer(), wfid); + var steps = + StepsDAO.listWorkflowSteps( + conn, ctx.schema(), ctx.serializer(), wfid, true, null, null); + var events = listWorkflowEvents(conn, ctx.schema(), wfid); + var eventHistory = listWorkflowEventHistory(conn, ctx.schema(), wfid); + var streams = listWorkflowStreams(conn, ctx.schema(), wfid); workflows.add(new ExportedWorkflow(status, steps, events, eventHistory, streams)); } } return workflows; } - static void importWorkflow( - DataSource dataSource, - String schema, - DBOSSerializer serializer, - List workflows) - throws SQLException { + static void importWorkflow(DbContext ctx, List workflows) throws SQLException { + + DBOSSerializer serializer = ctx.serializer(); var wfSQL = """ INSERT INTO "%s".workflow_status ( @@ -1553,7 +1515,7 @@ static void importWorkflow( ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ? ) """ - .formatted(schema); + .formatted(ctx.schema()); var stepSQL = """ @@ -1566,7 +1528,7 @@ static void importWorkflow( ?, ?, ?, ?, ?, ?, ?, ?, ? ) """ - .formatted(schema); + .formatted(ctx.schema()); var eventSQL = """ @@ -1576,7 +1538,7 @@ static void importWorkflow( ?, ?, ?, ? ) """ - .formatted(schema); + .formatted(ctx.schema()); var eventHistorySQL = """ @@ -1586,7 +1548,7 @@ static void importWorkflow( ?, ?, ?, ?, ? ) """ - .formatted(schema); + .formatted(ctx.schema()); var streamsSQL = """ @@ -1596,9 +1558,9 @@ static void importWorkflow( ?, ?, ?, ?, ?, ? ) """ - .formatted(schema); + .formatted(ctx.schema()); - try (var conn = dataSource.getConnection()) { + try (var conn = ctx.getConnection()) { conn.setAutoCommit(false); try (var wfStmt = conn.prepareStatement(wfSQL); @@ -1723,16 +1685,15 @@ static void importWorkflow( } } - static Map getAllEvents( - DataSource dataSource, String schema, DBOSSerializer serializer, String workflowId) - throws SQLException { - try (var conn = dataSource.getConnection()) { - var events = listWorkflowEvents(conn, schema, workflowId); + static Map getAllEvents(DbContext ctx, String workflowId) throws SQLException { + try (var conn = ctx.getConnection()) { + var events = listWorkflowEvents(conn, ctx.schema(), workflowId); var result = new LinkedHashMap(); for (var event : events) { result.put( event.key(), - SerializationUtil.deserializeValue(event.value(), event.serialization(), serializer)); + SerializationUtil.deserializeValue( + event.value(), event.serialization(), ctx.serializer())); } return result; } From 9e3af825a85dc6222b5d6f686348096dc942d157 Mon Sep 17 00:00:00 2001 From: Harry Pierson Date: Thu, 14 May 2026 15:29:54 -0700 Subject: [PATCH 08/27] WakeReason --- .../transact/database/SignalRegistry.java | 60 ++- .../transact/database/SignalRegistryTest.java | 371 ++++++++++-------- 2 files changed, 252 insertions(+), 179 deletions(-) diff --git a/transact/src/main/java/dev/dbos/transact/database/SignalRegistry.java b/transact/src/main/java/dev/dbos/transact/database/SignalRegistry.java index b035771a..190f910f 100644 --- a/transact/src/main/java/dev/dbos/transact/database/SignalRegistry.java +++ b/transact/src/main/java/dev/dbos/transact/database/SignalRegistry.java @@ -1,29 +1,59 @@ package dev.dbos.transact.database; +import java.time.Duration; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicInteger; sealed interface SignalKey permits SignalKey.Cancellation, SignalKey.Event, SignalKey.Message, SignalKey.Shutdown { - record Cancellation(String workflowId) implements SignalKey {} + WakeReason wakeReason(); - record Event(String workflowId, String topic) implements SignalKey {} + record Cancellation(String workflowId) implements SignalKey { + public WakeReason wakeReason() { + return WakeReason.CANCELLED; + } + } + + record Event(String workflowId, String topic) implements SignalKey { + public WakeReason wakeReason() { + return WakeReason.EVENT; + } + } + + record Message(String workflowId, String topic) implements SignalKey { + public WakeReason wakeReason() { + return WakeReason.MESSAGE; + } + } - record Message(String workflowId, String topic) implements SignalKey {} + record Shutdown() implements SignalKey { + public WakeReason wakeReason() { + return WakeReason.SHUTDOWN; + } + } +} - record Shutdown() implements SignalKey {} +enum WakeReason { + MESSAGE, + EVENT, + CANCELLED, + SHUTDOWN, + TIMEOUT } class SignalRegistry { private static class Entry { - final CompletableFuture future = new CompletableFuture<>(); + final CompletableFuture future = new CompletableFuture<>(); final AtomicInteger refs = new AtomicInteger(1); } - static class Subscription extends CompletableFuture implements AutoCloseable { + static class Subscription extends CompletableFuture implements AutoCloseable { private final Runnable onClose; Subscription(Runnable onClose) { @@ -58,13 +88,13 @@ public Subscription subscribe(SignalKey key) { if (e != null && e.refs.decrementAndGet() == 0) return null; return e; })); - entry.future.thenRun(() -> sub.complete(null)); + entry.future.thenAccept(sub::complete); return sub; } public void signal(SignalKey key) { Entry e = map.remove(key); - if (e != null) e.future.complete(null); + if (e != null) e.future.complete(key.wakeReason()); } Iterable keys() { @@ -74,4 +104,18 @@ Iterable keys() { static Subscription never() { return new Subscription(() -> {}); } + + static WakeReason awaitAny(Duration timeout, Subscription... subscriptions) { + try { + return (WakeReason) + CompletableFuture.anyOf(subscriptions).get(timeout.toMillis(), TimeUnit.MILLISECONDS); + } catch (TimeoutException ignored) { + return WakeReason.TIMEOUT; + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException(e); + } catch (ExecutionException e) { + throw new RuntimeException(e); + } + } } diff --git a/transact/src/test/java/dev/dbos/transact/database/SignalRegistryTest.java b/transact/src/test/java/dev/dbos/transact/database/SignalRegistryTest.java index 9a50dbb2..10d5e865 100644 --- a/transact/src/test/java/dev/dbos/transact/database/SignalRegistryTest.java +++ b/transact/src/test/java/dev/dbos/transact/database/SignalRegistryTest.java @@ -3,6 +3,7 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertTimeoutPreemptively; import static org.junit.jupiter.api.Assertions.assertTrue; @@ -18,7 +19,6 @@ class SignalRegistryTest { SignalRegistry registry; - // Reusable keys static final SignalKey KEY = new SignalKey.Cancellation("wf-1"); static final SignalKey KEY_A = new SignalKey.Cancellation("wf-a"); static final SignalKey KEY_B = new SignalKey.Cancellation("wf-b"); @@ -36,7 +36,6 @@ void testSignalKeyStructuralEquality() { assertEquals(new SignalKey.Cancellation("wf-1"), new SignalKey.Cancellation("wf-1")); assertEquals(new SignalKey.Event("wf-1", "t"), new SignalKey.Event("wf-1", "t")); assertNotEquals(new SignalKey.Cancellation("wf-1"), new SignalKey.Cancellation("wf-2")); - // Same fields, different types — must not be equal (prevents cross-type collisions) assertNotEquals( (SignalKey) new SignalKey.Cancellation("wf-1"), (SignalKey) new SignalKey.Event("wf-1", "wf-1")); @@ -45,19 +44,20 @@ void testSignalKeyStructuralEquality() { // --- Core subscribe / signal behaviour --- @Test - void testBasicSubscribeAndSignal() { - CompletableFuture f = registry.subscribe(KEY); + void testBasicSubscribeAndSignal() throws Exception { + var f = registry.subscribe(KEY); assertFalse(f.isDone()); registry.signal(KEY); assertTrue(f.isDone()); assertFalse(f.isCompletedExceptionally()); + assertEquals(WakeReason.CANCELLED, f.get()); } @Test - void testMultipleListenersOnSameKey() { - CompletableFuture f1 = registry.subscribe(KEY); - CompletableFuture f2 = registry.subscribe(KEY); - CompletableFuture f3 = registry.subscribe(KEY); + void testMultipleListenersOnSameKey() throws Exception { + var f1 = registry.subscribe(KEY); + var f2 = registry.subscribe(KEY); + var f3 = registry.subscribe(KEY); assertFalse(f1.isDone()); assertFalse(f2.isDone()); @@ -65,33 +65,30 @@ void testMultipleListenersOnSameKey() { registry.signal(KEY); - assertTrue(f1.isDone()); - assertTrue(f2.isDone()); - assertTrue(f3.isDone()); - assertFalse(f1.isCompletedExceptionally()); - assertFalse(f2.isCompletedExceptionally()); - assertFalse(f3.isCompletedExceptionally()); + assertEquals(WakeReason.CANCELLED, f1.get()); + assertEquals(WakeReason.CANCELLED, f2.get()); + assertEquals(WakeReason.CANCELLED, f3.get()); } @Test - void testMultipleSubscriptionsInAnyOf() { - CompletableFuture f1 = registry.subscribe(KEY_A); - CompletableFuture f2 = registry.subscribe(KEY_B); + void testMultipleSubscriptionsInAnyOf() throws Exception { + var f1 = registry.subscribe(KEY_A); + var f2 = registry.subscribe(KEY_B); - CompletableFuture anyOf = CompletableFuture.anyOf(f1, f2); + var anyOf = CompletableFuture.anyOf(f1, f2); assertFalse(anyOf.isDone()); registry.signal(KEY_B); - assertTrue(anyOf.isDone()); + assertEquals(WakeReason.CANCELLED, (WakeReason) anyOf.get(1, TimeUnit.SECONDS)); assertTrue(f2.isDone()); assertFalse(f1.isDone()); } @Test void testSignalOnlyWakesMatchingKey() { - CompletableFuture f1 = registry.subscribe(KEY_A); - CompletableFuture f2 = registry.subscribe(KEY_B); + var f1 = registry.subscribe(KEY_A); + var f2 = registry.subscribe(KEY_B); registry.signal(KEY_A); @@ -101,12 +98,11 @@ void testSignalOnlyWakesMatchingKey() { @Test void testDifferentKeyTypesWithSameFieldsDoNotCollide() { - // Event("wf-1", "wf-1") and Cancellation("wf-1") must occupy separate map entries - SignalKey eventKey = new SignalKey.Event("wf-1", "wf-1"); - SignalKey cancellationKey = new SignalKey.Cancellation("wf-1"); + var eventKey = new SignalKey.Event("wf-1", "wf-1"); + var cancellationKey = new SignalKey.Cancellation("wf-1"); - CompletableFuture f1 = registry.subscribe(eventKey); - CompletableFuture f2 = registry.subscribe(cancellationKey); + var f1 = registry.subscribe(eventKey); + var f2 = registry.subscribe(cancellationKey); registry.signal(cancellationKey); @@ -118,7 +114,7 @@ void testDifferentKeyTypesWithSameFieldsDoNotCollide() { void testSignalBeforeSubscribeDoesNotWake() { registry.signal(KEY); - CompletableFuture f = registry.subscribe(KEY); + var f = registry.subscribe(KEY); assertFalse(f.isDone()); registry.signal(KEY); @@ -127,35 +123,50 @@ void testSignalBeforeSubscribeDoesNotWake() { @Test void testSignalIsOneShot() { - CompletableFuture f1 = registry.subscribe(KEY); + var f1 = registry.subscribe(KEY); registry.signal(KEY); assertTrue(f1.isDone()); - CompletableFuture f2 = registry.subscribe(KEY); + var f2 = registry.subscribe(KEY); assertFalse(f2.isDone()); } + @Test + void testWakeReasonByKeyType() throws Exception { + var msgSub = registry.subscribe(new SignalKey.Message("wf-1", "topic")); + var eventSub = registry.subscribe(new SignalKey.Event("wf-1", "topic")); + var cancelSub = registry.subscribe(new SignalKey.Cancellation("wf-1")); + var shutdownSub = registry.subscribe(new SignalKey.Shutdown()); + + registry.signal(new SignalKey.Message("wf-1", "topic")); + registry.signal(new SignalKey.Event("wf-1", "topic")); + registry.signal(new SignalKey.Cancellation("wf-1")); + registry.signal(new SignalKey.Shutdown()); + + assertEquals(WakeReason.MESSAGE, msgSub.get()); + assertEquals(WakeReason.EVENT, eventSub.get()); + assertEquals(WakeReason.CANCELLED, cancelSub.get()); + assertEquals(WakeReason.SHUTDOWN, shutdownSub.get()); + } + // --- Subscription / close --- @Test void testCloseOneSubscriberDoesNotOrphanOthers() throws Exception { - // Closing one subscription on a shared key must not prevent the remaining subscriber - // from being woken when the signal fires (ref-counting behaviour). - SignalRegistry.Subscription sub1 = registry.subscribe(KEY); - SignalRegistry.Subscription sub2 = registry.subscribe(KEY); + var sub1 = registry.subscribe(KEY); + var sub2 = registry.subscribe(KEY); - sub1.close(); // ref count drops to 1 — key must stay in map + sub1.close(); registry.signal(KEY); - assertTrue(sub2.isDone()); - assertFalse(sub2.isCompletedExceptionally()); + assertEquals(WakeReason.CANCELLED, sub2.get()); } @Test void testClosePreventsFutureFromBeingSignalled() throws Exception { - SignalRegistry.Subscription sub = registry.subscribe(KEY); + var sub = registry.subscribe(KEY); sub.close(); - registry.signal(KEY); // no entry in map — should be a no-op + registry.signal(KEY); boolean completed = sub.orTimeout(100, TimeUnit.MILLISECONDS).handle((v, ex) -> ex == null).get(); @@ -164,18 +175,60 @@ void testClosePreventsFutureFromBeingSignalled() throws Exception { @Test void testNeverFutureNeverCompletes() throws Exception { - SignalRegistry.Subscription f = SignalRegistry.never(); + var f = SignalRegistry.never(); assertFalse(f.isDone()); boolean completed = f.orTimeout(100, TimeUnit.MILLISECONDS).handle((v, ex) -> ex == null).get(); assertFalse(completed); } + // --- keys() --- + + @Test + void testKeysReflectsActiveSubscriptions() { + registry.subscribe(KEY_A); + registry.subscribe(KEY_B); + + var keys = registry.keys(); + assertTrue(iterableContains(keys, KEY_A)); + assertTrue(iterableContains(keys, KEY_B)); + } + + @Test + void testKeysExcludesSignalledKey() { + registry.subscribe(KEY_A); + registry.subscribe(KEY_B); + registry.signal(KEY_A); + + var keys = registry.keys(); + assertFalse(iterableContains(keys, KEY_A)); + assertTrue(iterableContains(keys, KEY_B)); + } + + @Test + void testKeysEmptyWhenNoSubscribers() { + assertFalse(registry.keys().iterator().hasNext()); + } + + @Test + void testKeysExcludesKeyAfterLastSubscriberCloses() { + var sub = registry.subscribe(KEY); + sub.close(); + assertFalse(registry.keys().iterator().hasNext()); + } + + private static boolean iterableContains(Iterable keys, SignalKey target) { + for (SignalKey k : keys) { + if (k.equals(target)) return true; + } + return false; + } + // --- Threading --- @Test void testSubscribeBeforeSignalFromAnotherThread() throws Exception { - CompletableFuture f = registry.subscribe(FOO); + var f = registry.subscribe(FOO); CompletableFuture.runAsync( () -> { @@ -192,14 +245,13 @@ void testSubscribeBeforeSignalFromAnotherThread() throws Exception { @Test void testSignalFromMainAfterBackgroundSubscribes() throws Exception { - CompletableFuture backgroundDone = new CompletableFuture<>(); + var backgroundDone = new CompletableFuture(); CompletableFuture.runAsync( () -> { - CompletableFuture f = registry.subscribe(FOO); + var f = registry.subscribe(FOO); try { - f.get(500, TimeUnit.MILLISECONDS); - backgroundDone.complete(null); + backgroundDone.complete(f.get(500, TimeUnit.MILLISECONDS)); } catch (Exception e) { backgroundDone.completeExceptionally(e); } @@ -209,18 +261,18 @@ void testSignalFromMainAfterBackgroundSubscribes() throws Exception { registry.signal(FOO); assertTimeoutPreemptively(Duration.ofSeconds(1), (Executable) backgroundDone::get); + assertEquals(WakeReason.CANCELLED, backgroundDone.get()); } @Test void testMultipleSubscribersInSeparateThreads() throws Exception { - CompletableFuture done1 = new CompletableFuture<>(); - CompletableFuture done2 = new CompletableFuture<>(); + var done1 = new CompletableFuture(); + var done2 = new CompletableFuture(); CompletableFuture.runAsync( () -> { try { - registry.subscribe(FOO).get(500, TimeUnit.MILLISECONDS); - done1.complete(null); + done1.complete(registry.subscribe(FOO).get(500, TimeUnit.MILLISECONDS)); } catch (Exception e) { done1.completeExceptionally(e); } @@ -228,8 +280,7 @@ void testMultipleSubscribersInSeparateThreads() throws Exception { CompletableFuture.runAsync( () -> { try { - registry.subscribe(FOO).get(500, TimeUnit.MILLISECONDS); - done2.complete(null); + done2.complete(registry.subscribe(FOO).get(500, TimeUnit.MILLISECONDS)); } catch (Exception e) { done2.completeExceptionally(e); } @@ -241,8 +292,8 @@ void testMultipleSubscribersInSeparateThreads() throws Exception { assertTimeoutPreemptively( Duration.ofSeconds(1), () -> { - done1.get(); - done2.get(); + assertEquals(WakeReason.CANCELLED, done1.get()); + assertEquals(WakeReason.CANCELLED, done2.get()); }); } @@ -252,232 +303,210 @@ void testConcurrentSignalAndSubscribe() throws Exception { Duration.ofSeconds(5), () -> { for (int i = 0; i < 1000; i++) { - SignalRegistry r = new SignalRegistry(); - CompletableFuture sub = r.subscribe(KEY); + var r = new SignalRegistry(); + var sub = r.subscribe(KEY); CompletableFuture.runAsync(() -> r.signal(KEY)); sub.get(1, TimeUnit.SECONDS); } }); } - // --- keys() --- + // --- awaitAny --- @Test - void testKeysReflectsActiveSubscriptions() { - registry.subscribe(KEY_A); - registry.subscribe(KEY_B); + void testAwaitAny_notifyFires() throws Exception { + var onNotify = registry.subscribe(new SignalKey.Event("wf-1", "topic")); + var onCancelled = registry.subscribe(new SignalKey.Cancellation("wf-1")); + var onShutdown = SignalRegistry.never(); - Iterable keys = registry.keys(); - assertTrue(iterableContains(keys, KEY_A)); - assertTrue(iterableContains(keys, KEY_B)); + registry.signal(new SignalKey.Event("wf-1", "topic")); + + assertEquals( + WakeReason.EVENT, + SignalRegistry.awaitAny(Duration.ofSeconds(1), onNotify, onCancelled, onShutdown)); } @Test - void testKeysExcludesSignalledKey() { - registry.subscribe(KEY_A); - registry.subscribe(KEY_B); - registry.signal(KEY_A); + void testAwaitAny_cancelledFires() throws Exception { + var onNotify = registry.subscribe(new SignalKey.Event("wf-1", "topic")); + var onCancelled = registry.subscribe(new SignalKey.Cancellation("wf-1")); + var onShutdown = SignalRegistry.never(); - Iterable keys = registry.keys(); - assertFalse(iterableContains(keys, KEY_A)); - assertTrue(iterableContains(keys, KEY_B)); - } + registry.signal(new SignalKey.Cancellation("wf-1")); - @Test - void testKeysEmptyWhenNoSubscribers() { - assertFalse(registry.keys().iterator().hasNext()); + assertEquals( + WakeReason.CANCELLED, + SignalRegistry.awaitAny(Duration.ofSeconds(1), onNotify, onCancelled, onShutdown)); } @Test - void testKeysExcludesKeyAfterLastSubscriberCloses() { - SignalRegistry.Subscription sub = registry.subscribe(KEY); - sub.close(); - assertFalse(registry.keys().iterator().hasNext()); + void testAwaitAny_shutdownFires() throws Exception { + var onNotify = registry.subscribe(new SignalKey.Event("wf-1", "topic")); + var onCancelled = SignalRegistry.never(); + var onShutdown = registry.subscribe(new SignalKey.Shutdown()); + + registry.signal(new SignalKey.Shutdown()); + + assertEquals( + WakeReason.SHUTDOWN, + SignalRegistry.awaitAny(Duration.ofSeconds(1), onNotify, onCancelled, onShutdown)); } - private static boolean iterableContains(Iterable keys, SignalKey target) { - for (SignalKey k : keys) { - if (k.equals(target)) return true; - } - return false; + @Test + void testAwaitAny_timeout() throws Exception { + var onNotify = registry.subscribe(new SignalKey.Event("wf-1", "topic")); + var onCancelled = SignalRegistry.never(); + var onShutdown = SignalRegistry.never(); + + assertEquals( + WakeReason.TIMEOUT, + SignalRegistry.awaitAny(Duration.ofMillis(50), onNotify, onCancelled, onShutdown)); } // --- anyOf determination via isDone (Option A) --- @Test void testCheckIsDone_notifyFires() throws Exception { - SignalKey notifyKey = new SignalKey.Event("wf-1", "topic"); - SignalKey cancelKey = new SignalKey.Cancellation("wf-1"); + var notifyKey = new SignalKey.Event("wf-1", "topic"); + var cancelKey = new SignalKey.Cancellation("wf-1"); - CompletableFuture onNotify = registry.subscribe(notifyKey); - CompletableFuture onCancelled = registry.subscribe(cancelKey); - CompletableFuture onDbClosed = SignalRegistry.never(); + var onNotify = registry.subscribe(notifyKey); + var onCancelled = registry.subscribe(cancelKey); + var onShutdown = SignalRegistry.never(); registry.signal(notifyKey); try { - CompletableFuture.anyOf(onNotify, onCancelled, onDbClosed).get(1, TimeUnit.SECONDS); + CompletableFuture.anyOf(onNotify, onCancelled, onShutdown).get(1, TimeUnit.SECONDS); } catch (java.util.concurrent.TimeoutException ignored) { } - assertFalse(onCancelled.isDone() || onDbClosed.isDone()); + assertFalse(onCancelled.isDone() || onShutdown.isDone()); assertTrue(onNotify.isDone()); } @Test void testCheckIsDone_cancelledFires() throws Exception { - SignalKey notifyKey = new SignalKey.Event("wf-1", "topic"); - SignalKey cancelKey = new SignalKey.Cancellation("wf-1"); + var notifyKey = new SignalKey.Event("wf-1", "topic"); + var cancelKey = new SignalKey.Cancellation("wf-1"); - CompletableFuture onNotify = registry.subscribe(notifyKey); - CompletableFuture onCancelled = registry.subscribe(cancelKey); - CompletableFuture onDbClosed = SignalRegistry.never(); + var onNotify = registry.subscribe(notifyKey); + var onCancelled = registry.subscribe(cancelKey); + var onShutdown = SignalRegistry.never(); registry.signal(cancelKey); try { - CompletableFuture.anyOf(onNotify, onCancelled, onDbClosed).get(1, TimeUnit.SECONDS); + CompletableFuture.anyOf(onNotify, onCancelled, onShutdown).get(1, TimeUnit.SECONDS); } catch (java.util.concurrent.TimeoutException ignored) { } - assertTrue(onCancelled.isDone() || onDbClosed.isDone()); + assertTrue(onCancelled.isDone() || onShutdown.isDone()); assertFalse(onNotify.isDone()); } @Test - void testCheckIsDone_dbClosedFires() throws Exception { - CompletableFuture onNotify = registry.subscribe(new SignalKey.Event("wf-1", "topic")); - CompletableFuture onCancelled = SignalRegistry.never(); - CompletableFuture onDbClosed = new CompletableFuture<>(); - onDbClosed.complete(null); + void testCheckIsDone_shutdownFires() throws Exception { + var onNotify = registry.subscribe(new SignalKey.Event("wf-1", "topic")); + var onCancelled = SignalRegistry.never(); + var onShutdown = registry.subscribe(new SignalKey.Shutdown()); + + registry.signal(new SignalKey.Shutdown()); try { - CompletableFuture.anyOf(onNotify, onCancelled, onDbClosed).get(1, TimeUnit.SECONDS); + CompletableFuture.anyOf(onNotify, onCancelled, onShutdown).get(1, TimeUnit.SECONDS); } catch (java.util.concurrent.TimeoutException ignored) { } - assertTrue(onCancelled.isDone() || onDbClosed.isDone()); + assertTrue(onCancelled.isDone() || onShutdown.isDone()); assertFalse(onNotify.isDone()); } @Test void testCheckIsDone_timeout() throws Exception { - CompletableFuture onNotify = registry.subscribe(new SignalKey.Event("wf-1", "topic")); - CompletableFuture onCancelled = SignalRegistry.never(); - CompletableFuture onDbClosed = SignalRegistry.never(); + var onNotify = registry.subscribe(new SignalKey.Event("wf-1", "topic")); + var onCancelled = SignalRegistry.never(); + var onShutdown = SignalRegistry.never(); try { - CompletableFuture.anyOf(onNotify, onCancelled, onDbClosed).get(50, TimeUnit.MILLISECONDS); + CompletableFuture.anyOf(onNotify, onCancelled, onShutdown).get(50, TimeUnit.MILLISECONDS); } catch (java.util.concurrent.TimeoutException ignored) { } - assertFalse(onCancelled.isDone() || onDbClosed.isDone()); + assertFalse(onCancelled.isDone() || onShutdown.isDone()); assertFalse(onNotify.isDone()); } // --- anyOf determination via tagged dispatch (Option B) --- + // WakeReason is now embedded in each Subscription — no thenApply needed. @Test void testTaggedDispatch_notifyFires() throws Exception { - SignalKey notifyKey = new SignalKey.Event("wf-1", "topic"); - SignalKey cancelKey = new SignalKey.Cancellation("wf-1"); + var notifyKey = new SignalKey.Event("wf-1", "topic"); + var cancelKey = new SignalKey.Cancellation("wf-1"); - CompletableFuture onNotify = registry.subscribe(notifyKey); - CompletableFuture onCancelled = registry.subscribe(cancelKey); - CompletableFuture onDbClosed = SignalRegistry.never(); + var onNotify = registry.subscribe(notifyKey); + var onCancelled = registry.subscribe(cancelKey); + var onShutdown = SignalRegistry.never(); registry.signal(notifyKey); - enum WakeReason { - NOTIFY, - CANCELLED, - DB_CLOSED - } - WakeReason reason = + var reason = (WakeReason) - CompletableFuture.anyOf( - onNotify.thenApply(v -> WakeReason.NOTIFY), - onCancelled.thenApply(v -> WakeReason.CANCELLED), - onDbClosed.thenApply(v -> WakeReason.DB_CLOSED)) - .get(1, TimeUnit.SECONDS); + CompletableFuture.anyOf(onNotify, onCancelled, onShutdown).get(1, TimeUnit.SECONDS); - assertEquals(WakeReason.NOTIFY, reason); + assertEquals(WakeReason.EVENT, reason); } @Test void testTaggedDispatch_cancelledFires() throws Exception { - SignalKey notifyKey = new SignalKey.Event("wf-1", "topic"); - SignalKey cancelKey = new SignalKey.Cancellation("wf-1"); + var notifyKey = new SignalKey.Event("wf-1", "topic"); + var cancelKey = new SignalKey.Cancellation("wf-1"); - CompletableFuture onNotify = registry.subscribe(notifyKey); - CompletableFuture onCancelled = registry.subscribe(cancelKey); - CompletableFuture onDbClosed = SignalRegistry.never(); + var onNotify = registry.subscribe(notifyKey); + var onCancelled = registry.subscribe(cancelKey); + var onShutdown = SignalRegistry.never(); registry.signal(cancelKey); - enum WakeReason { - NOTIFY, - CANCELLED, - DB_CLOSED - } - WakeReason reason = + var reason = (WakeReason) - CompletableFuture.anyOf( - onNotify.thenApply(v -> WakeReason.NOTIFY), - onCancelled.thenApply(v -> WakeReason.CANCELLED), - onDbClosed.thenApply(v -> WakeReason.DB_CLOSED)) - .get(1, TimeUnit.SECONDS); + CompletableFuture.anyOf(onNotify, onCancelled, onShutdown).get(1, TimeUnit.SECONDS); assertEquals(WakeReason.CANCELLED, reason); } @Test - void testTaggedDispatch_dbClosedFires() throws Exception { - CompletableFuture onNotify = registry.subscribe(new SignalKey.Event("wf-1", "topic")); - CompletableFuture onCancelled = SignalRegistry.never(); - CompletableFuture onDbClosed = new CompletableFuture<>(); - onDbClosed.complete(null); + void testTaggedDispatch_shutdownFires() throws Exception { + var onNotify = registry.subscribe(new SignalKey.Event("wf-1", "topic")); + var onCancelled = SignalRegistry.never(); + var onShutdown = registry.subscribe(new SignalKey.Shutdown()); - enum WakeReason { - NOTIFY, - CANCELLED, - DB_CLOSED - } - WakeReason reason = + registry.signal(new SignalKey.Shutdown()); + + var reason = (WakeReason) - CompletableFuture.anyOf( - onNotify.thenApply(v -> WakeReason.NOTIFY), - onCancelled.thenApply(v -> WakeReason.CANCELLED), - onDbClosed.thenApply(v -> WakeReason.DB_CLOSED)) - .get(1, TimeUnit.SECONDS); + CompletableFuture.anyOf(onNotify, onCancelled, onShutdown).get(1, TimeUnit.SECONDS); - assertEquals(WakeReason.DB_CLOSED, reason); + assertEquals(WakeReason.SHUTDOWN, reason); } @Test void testTaggedDispatch_timeout() throws Exception { - CompletableFuture onNotify = registry.subscribe(new SignalKey.Event("wf-1", "topic")); - CompletableFuture onCancelled = SignalRegistry.never(); - CompletableFuture onDbClosed = SignalRegistry.never(); - - enum WakeReason { - NOTIFY, - CANCELLED, - DB_CLOSED - } + var onNotify = registry.subscribe(new SignalKey.Event("wf-1", "topic")); + var onCancelled = SignalRegistry.never(); + var onShutdown = SignalRegistry.never(); + WakeReason reason = null; try { reason = (WakeReason) - CompletableFuture.anyOf( - onNotify.thenApply(v -> WakeReason.NOTIFY), - onCancelled.thenApply(v -> WakeReason.CANCELLED), - onDbClosed.thenApply(v -> WakeReason.DB_CLOSED)) + CompletableFuture.anyOf(onNotify, onCancelled, onShutdown) .get(50, TimeUnit.MILLISECONDS); } catch (java.util.concurrent.TimeoutException ignored) { } assertFalse(onNotify.isDone()); - assertEquals(null, reason); + assertNull(reason); } } From afc03a80fea9eb3269afc04963de502f6da48b14 Mon Sep 17 00:00:00 2001 From: Harry Pierson Date: Thu, 14 May 2026 18:08:55 -0700 Subject: [PATCH 09/27] refactor signal registry -> signal map. recv/getEvent currently stubbed out --- ...e.java => NotificationListenerSource.java} | 68 +-- .../transact/database/NotificationsDAO.java | 508 +++++++++--------- .../dev/dbos/transact/database/SignalKey.java | 39 ++ .../dev/dbos/transact/database/SignalMap.java | 68 +++ .../transact/database/SignalRegistry.java | 121 ----- .../dbos/transact/database/Subscription.java | 18 + .../transact/database/SystemDatabase.java | 25 +- ...alRegistryTest.java => SignalMapTest.java} | 243 ++++----- 8 files changed, 522 insertions(+), 568 deletions(-) rename transact/src/main/java/dev/dbos/transact/database/{NotificationService.java => NotificationListenerSource.java} (65%) create mode 100644 transact/src/main/java/dev/dbos/transact/database/SignalKey.java create mode 100644 transact/src/main/java/dev/dbos/transact/database/SignalMap.java delete mode 100644 transact/src/main/java/dev/dbos/transact/database/SignalRegistry.java create mode 100644 transact/src/main/java/dev/dbos/transact/database/Subscription.java rename transact/src/test/java/dev/dbos/transact/database/{SignalRegistryTest.java => SignalMapTest.java} (61%) diff --git a/transact/src/main/java/dev/dbos/transact/database/NotificationService.java b/transact/src/main/java/dev/dbos/transact/database/NotificationListenerSource.java similarity index 65% rename from transact/src/main/java/dev/dbos/transact/database/NotificationService.java rename to transact/src/main/java/dev/dbos/transact/database/NotificationListenerSource.java index 2681b696..eb2a652c 100644 --- a/transact/src/main/java/dev/dbos/transact/database/NotificationService.java +++ b/transact/src/main/java/dev/dbos/transact/database/NotificationListenerSource.java @@ -1,13 +1,11 @@ package dev.dbos.transact.database; +import dev.dbos.transact.database.SystemDatabase.NotificationSource; + import java.sql.Connection; import java.sql.SQLException; import java.sql.Statement; -import java.util.Map; -import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicReference; -import java.util.concurrent.locks.Condition; -import java.util.concurrent.locks.ReentrantLock; import javax.sql.DataSource; @@ -16,35 +14,31 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -public class NotificationService { - - public static class LockConditionPair { - public final ReentrantLock lock = new ReentrantLock(); - public final Condition condition = lock.newCondition(); - } +class NotificationListenerSource implements NotificationSource { - private static final Logger logger = LoggerFactory.getLogger(NotificationService.class); + private static final Logger logger = LoggerFactory.getLogger(NotificationListenerSource.class); - private final Map notificationsMap = new ConcurrentHashMap<>(); - private final AtomicReference notificationListenerThread = new AtomicReference<>(null); private final DataSource dataSource; + private final AtomicReference notificationListenerThread = new AtomicReference<>(null); + private final SignalMap signalMap = new SignalMap<>(); - public NotificationService(DataSource dataSource) { + public NotificationListenerSource(DataSource dataSource) { this.dataSource = dataSource; } - public boolean registerNotificationCondition(String key, LockConditionPair pair) { - return notificationsMap.putIfAbsent(key, pair) == null; - } - - public LockConditionPair getOrCreateNotificationCondition(String key) { - return notificationsMap.computeIfAbsent(key, k -> new LockConditionPair()); + @Override + public Subscription subscribe(SignalKey.Message key) { + var strKey = "m::%s::%s".formatted(key.workflowId(), key.topic()); + return signalMap.subscribe(strKey, key.wakeReason()); } - public void unregisterNotificationCondition(String key) { - notificationsMap.remove(key); + @Override + public Subscription subscribe(SignalKey.Event key) { + var strKey = "e::%s::%s".formatted(key.workflowId(), key.key()); + return signalMap.subscribe(strKey, key.wakeReason()); } + @Override public void start() { Thread t = new Thread(this::notificationListener, "NotificationListener"); t.setDaemon(true); @@ -54,7 +48,8 @@ public void start() { } } - public void stop() { + @Override + public void close() { Thread t = notificationListenerThread.getAndSet(null); if (t != null) { t.interrupt(); @@ -65,7 +60,6 @@ public void stop() { } } - notificationsMap.clear(); logger.debug("Notification listener stopped"); } @@ -103,9 +97,8 @@ private void notificationListener() { logger.error("Received notification with null channel. Payload: {}", payload); } else switch (channel) { - case "dbos_notifications_channel" -> handleNotification(payload, "notifications"); - case "dbos_workflow_events_channel" -> - handleNotification(payload, "workflow_events"); + case "dbos_notifications_channel" -> signalMap.signal("m::" + payload); + case "dbos_workflow_events_channel" -> signalMap.signal("e::" + payload); default -> logger.error("Unknown NOTIFY channel: {}", channel); } } @@ -136,25 +129,4 @@ private void notificationListener() { } logger.debug("Notification listener thread exiting"); } - - private void handleNotification(String payload, String mapType) { - - logger.debug("Received notification for {}", payload); - - if (payload != null && !payload.isEmpty()) { - LockConditionPair pair = notificationsMap.get(payload); - if (pair != null) { - pair.lock.lock(); - try { - pair.condition.signalAll(); - } finally { - pair.lock.unlock(); - } - logger.debug("Signaled {} condition for {}", mapType, payload); - } else { - logger.debug("ConditionMap has no entry for {}", payload); - } - // If no condition found, we simply ignore the notification - } - } } diff --git a/transact/src/main/java/dev/dbos/transact/database/NotificationsDAO.java b/transact/src/main/java/dev/dbos/transact/database/NotificationsDAO.java index 3bee63fe..3d53d84e 100644 --- a/transact/src/main/java/dev/dbos/transact/database/NotificationsDAO.java +++ b/transact/src/main/java/dev/dbos/transact/database/NotificationsDAO.java @@ -1,8 +1,8 @@ package dev.dbos.transact.database; import dev.dbos.transact.Constants; +import dev.dbos.transact.database.SystemDatabase.NotifcationRegistry; import dev.dbos.transact.exceptions.DBOSNonExistentWorkflowException; -import dev.dbos.transact.exceptions.DBOSWorkflowExecutionConflictException; import dev.dbos.transact.json.DBOSSerializer; import dev.dbos.transact.json.SerializationUtil; import dev.dbos.transact.workflow.NotificationInfo; @@ -10,15 +10,12 @@ import java.sql.Connection; import java.sql.PreparedStatement; -import java.sql.ResultSet; import java.sql.SQLException; import java.time.Duration; import java.time.Instant; import java.util.ArrayList; import java.util.List; -import java.util.Objects; import java.util.UUID; -import java.util.concurrent.TimeUnit; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -151,7 +148,7 @@ ON CONFLICT (message_uuid) DO NOTHING static Object recv( DbContext ctx, - NotificationService notificationService, + NotifcationRegistry notifcationRegistry, Duration dbPollingInterval, String workflowId, int stepId, @@ -160,144 +157,146 @@ static Object recv( Duration timeout) throws SQLException { - DBOSSerializer serializer = ctx.serializer(); - var startTime = System.currentTimeMillis(); - String functionName = "DBOS.recv"; - String finalTopic = (topic != null) ? topic : Constants.DBOS_NULL_TOPIC; - - StepResult recordedOutput; - try (Connection c = ctx.getConnection()) { - recordedOutput = - StepsDAO.checkStepExecutionTxn(c, ctx.schema(), workflowId, stepId, functionName); - } - - if (recordedOutput != null) { - logger.debug("Replaying recv, id: {}, topic: {}", stepId, finalTopic); - if (recordedOutput.output() != null) { - return SerializationUtil.deserializeValue( - recordedOutput.output(), recordedOutput.serialization(), serializer); - } else { - throw new RuntimeException("No output recorded in the last recv"); - } - } else { - logger.debug("Running recv, wfid {}, id: {}, topic: {}", workflowId, stepId, finalTopic); - } - - String payload = workflowId + "::" + finalTopic; - var lockPair = new NotificationService.LockConditionPair(); - - double actualTimeout = timeout.toMillis(); - var targetTime = System.currentTimeMillis() + actualTimeout; - var checkedDBForSleep = false; - - try { - lockPair.lock.lock(); - boolean success = notificationService.registerNotificationCondition(payload, lockPair); - if (!success) { - throw new DBOSWorkflowExecutionConflictException(workflowId); - } - - while (true) { - if (ctx.isClosed()) throw new IllegalStateException("SystemDatabase is closed"); - boolean hasExistingNotification; - try (Connection conn = ctx.getConnection()) { - final String sql = - """ - SELECT topic FROM "%s".notifications - WHERE destination_uuid = ? AND topic = ? AND consumed = FALSE - """ - .formatted(ctx.schema()); - - try (PreparedStatement stmt = conn.prepareStatement(sql)) { - stmt.setString(1, workflowId); - stmt.setString(2, finalTopic); - try (ResultSet rs = stmt.executeQuery()) { - hasExistingNotification = rs.next(); - } - } - } - - if (hasExistingNotification) break; - - var nowTime = System.currentTimeMillis(); - - if (!checkedDBForSleep) { - actualTimeout = - StepsDAO.durableSleepDuration(ctx, workflowId, timeoutFunctionId, timeout).toMillis(); - checkedDBForSleep = true; - targetTime = nowTime + actualTimeout; - } - if (nowTime >= targetTime) break; - long timeoutMs = (long) Math.min(targetTime - nowTime, dbPollingInterval.toMillis()); - - try { - lockPair.condition.await(timeoutMs, TimeUnit.MILLISECONDS); - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - throw new RuntimeException("Interrupted while waiting for message", e); - } - } - } finally { - lockPair.lock.unlock(); - notificationService.unregisterNotificationCondition(payload); - } - - try (Connection conn = ctx.getConnection()) { - conn.setAutoCommit(false); - - try { - final String sql = - """ - UPDATE "%1$s".notifications - SET consumed = TRUE - WHERE destination_uuid = ? - AND topic = ? - AND consumed = FALSE - AND message_uuid = ( - SELECT message_uuid FROM "%1$s".notifications - WHERE destination_uuid = ? - AND topic = ? - AND consumed = FALSE - ORDER BY created_at_epoch_ms ASC - LIMIT 1 - ) - RETURNING message, serialization - """ - .formatted(ctx.schema()); - - String serializedMessage = null; - String serialization = null; - try (PreparedStatement stmt = conn.prepareStatement(sql)) { - stmt.setString(1, workflowId); - stmt.setString(2, finalTopic); - stmt.setString(3, workflowId); - stmt.setString(4, finalTopic); - - try (ResultSet rs = stmt.executeQuery()) { - if (rs.next()) { - serializedMessage = rs.getString("message"); - serialization = rs.getString("serialization"); - } - } - } - - var recvdMessage = - SerializationUtil.deserializeValue(serializedMessage, serialization, serializer); - - StepResult output = - new StepResult( - workflowId, stepId, functionName, serializedMessage, null, null, serialization); - StepsDAO.recordStepResultTxn( - conn, ctx.schema(), output, startTime, System.currentTimeMillis()); - - conn.commit(); - return recvdMessage; - - } catch (Exception e) { - conn.rollback(); - throw e; - } - } + return null; + // DBOSSerializer serializer = ctx.serializer(); + // var startTime = System.currentTimeMillis(); + // String functionName = "DBOS.recv"; + // String finalTopic = (topic != null) ? topic : Constants.DBOS_NULL_TOPIC; + + // StepResult recordedOutput; + // try (Connection c = ctx.getConnection()) { + // recordedOutput = + // StepsDAO.checkStepExecutionTxn(c, ctx.schema(), workflowId, stepId, functionName); + // } + + // if (recordedOutput != null) { + // logger.debug("Replaying recv, id: {}, topic: {}", stepId, finalTopic); + // if (recordedOutput.output() != null) { + // return SerializationUtil.deserializeValue( + // recordedOutput.output(), recordedOutput.serialization(), serializer); + // } else { + // throw new RuntimeException("No output recorded in the last recv"); + // } + // } else { + // logger.debug("Running recv, wfid {}, id: {}, topic: {}", workflowId, stepId, finalTopic); + // } + + // String payload = workflowId + "::" + finalTopic; + // var lockPair = new NotificationListenerService.LockConditionPair(); + + // double actualTimeout = timeout.toMillis(); + // var targetTime = System.currentTimeMillis() + actualTimeout; + // var checkedDBForSleep = false; + + // try { + // lockPair.lock.lock(); + // boolean success = notificationService.registerNotificationCondition(payload, lockPair); + // if (!success) { + // throw new DBOSWorkflowExecutionConflictException(workflowId); + // } + + // while (true) { + // if (ctx.isClosed()) throw new IllegalStateException("SystemDatabase is closed"); + // boolean hasExistingNotification; + // try (Connection conn = ctx.getConnection()) { + // final String sql = + // """ + // SELECT topic FROM "%s".notifications + // WHERE destination_uuid = ? AND topic = ? AND consumed = FALSE + // """ + // .formatted(ctx.schema()); + + // try (PreparedStatement stmt = conn.prepareStatement(sql)) { + // stmt.setString(1, workflowId); + // stmt.setString(2, finalTopic); + // try (ResultSet rs = stmt.executeQuery()) { + // hasExistingNotification = rs.next(); + // } + // } + // } + + // if (hasExistingNotification) break; + + // var nowTime = System.currentTimeMillis(); + + // if (!checkedDBForSleep) { + // actualTimeout = + // StepsDAO.durableSleepDuration(ctx, workflowId, timeoutFunctionId, + // timeout).toMillis(); + // checkedDBForSleep = true; + // targetTime = nowTime + actualTimeout; + // } + // if (nowTime >= targetTime) break; + // long timeoutMs = (long) Math.min(targetTime - nowTime, dbPollingInterval.toMillis()); + + // try { + // lockPair.condition.await(timeoutMs, TimeUnit.MILLISECONDS); + // } catch (InterruptedException e) { + // Thread.currentThread().interrupt(); + // throw new RuntimeException("Interrupted while waiting for message", e); + // } + // } + // } finally { + // lockPair.lock.unlock(); + // notificationService.unregisterNotificationCondition(payload); + // } + + // try (Connection conn = ctx.getConnection()) { + // conn.setAutoCommit(false); + + // try { + // final String sql = + // """ + // UPDATE "%1$s".notifications + // SET consumed = TRUE + // WHERE destination_uuid = ? + // AND topic = ? + // AND consumed = FALSE + // AND message_uuid = ( + // SELECT message_uuid FROM "%1$s".notifications + // WHERE destination_uuid = ? + // AND topic = ? + // AND consumed = FALSE + // ORDER BY created_at_epoch_ms ASC + // LIMIT 1 + // ) + // RETURNING message, serialization + // """ + // .formatted(ctx.schema()); + + // String serializedMessage = null; + // String serialization = null; + // try (PreparedStatement stmt = conn.prepareStatement(sql)) { + // stmt.setString(1, workflowId); + // stmt.setString(2, finalTopic); + // stmt.setString(3, workflowId); + // stmt.setString(4, finalTopic); + + // try (ResultSet rs = stmt.executeQuery()) { + // if (rs.next()) { + // serializedMessage = rs.getString("message"); + // serialization = rs.getString("serialization"); + // } + // } + // } + + // var recvdMessage = + // SerializationUtil.deserializeValue(serializedMessage, serialization, serializer); + + // StepResult output = + // new StepResult( + // workflowId, stepId, functionName, serializedMessage, null, null, serialization); + // StepsDAO.recordStepResultTxn( + // conn, ctx.schema(), output, startTime, System.currentTimeMillis()); + + // conn.commit(); + // return recvdMessage; + + // } catch (Exception e) { + // conn.rollback(); + // throw e; + // } + // } } private static void setEvent( @@ -408,7 +407,7 @@ static void setEvent( static Object getEvent( DbContext ctx, - NotificationService notificationService, + NotifcationRegistry notifcationRegistry, Duration dbPollingInterval, String targetUuid, String key, @@ -416,115 +415,118 @@ static Object getEvent( GetWorkflowEventContext callerCtx) throws SQLException { - DBOSSerializer serializer = ctx.serializer(); - var startTime = System.currentTimeMillis(); - String functionName = "DBOS.getEvent"; - - if (callerCtx != null) { - StepResult recordedOutput; - try (Connection conn = ctx.getConnection()) { - recordedOutput = - StepsDAO.checkStepExecutionTxn( - conn, ctx.schema(), callerCtx.workflowId(), callerCtx.functionId(), functionName); - } - - if (recordedOutput != null) { - logger.debug("Replaying getEvent, id: {}, key: {}", callerCtx.functionId(), key); - if (recordedOutput.output() != null) { - return SerializationUtil.deserializeValue( - recordedOutput.output(), recordedOutput.serialization(), serializer); - } else { - throw new RuntimeException("No output recorded in the last getEvent"); - } - } else { - logger.debug("Running getEvent, id: {}, key: {}", callerCtx.functionId(), key); - } - } - - String payload = targetUuid + "::" + key; - NotificationService.LockConditionPair lockConditionPair = - notificationService.getOrCreateNotificationCondition(payload); - - lockConditionPair.lock.lock(); - try { - Object value = null; - final String sql = - """ - SELECT value, serialization FROM "%s".workflow_events WHERE workflow_uuid = ? AND key = ? - """ - .formatted(ctx.schema()); - - double actualTimeout = - Objects.requireNonNull(timeout, "getEvent timeout cannot be null").toMillis(); - var targetTime = System.currentTimeMillis() + actualTimeout; - var checkedDBForSleep = false; - var hasExistingNotification = false; - - while (true) { - if (ctx.isClosed()) throw new IllegalStateException("SystemDatabase is closed"); - try (Connection conn = ctx.getConnection(); - PreparedStatement stmt = conn.prepareStatement(sql)) { - - stmt.setString(1, targetUuid); - stmt.setString(2, key); - - try (ResultSet rs = stmt.executeQuery()) { - if (rs.next()) { - String serializedValue = rs.getString("value"); - String serialization = rs.getString("serialization"); - value = - SerializationUtil.deserializeValue(serializedValue, serialization, serializer); - hasExistingNotification = true; - } - } - } - - if (hasExistingNotification) break; - var nowTime = System.currentTimeMillis(); - if (nowTime > targetTime) break; - - if (callerCtx != null && !checkedDBForSleep) { - actualTimeout = - StepsDAO.durableSleepDuration( - ctx, callerCtx.workflowId(), callerCtx.timeoutFunctionId(), timeout) - .toMillis(); - targetTime = System.currentTimeMillis() + actualTimeout; - checkedDBForSleep = true; - if (nowTime > targetTime) break; - } - - try { - long timeoutms = (long) (targetTime - nowTime); - logger.debug("Waiting for notification {}...", timeout); - lockConditionPair.condition.await( - Math.min(timeoutms, dbPollingInterval.toMillis()), TimeUnit.MILLISECONDS); - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - throw new RuntimeException("Interrupted while waiting for event", e); - } - } - - if (callerCtx != null) { - var toSaveSer = SerializationUtil.serializeValue(value, null, serializer); - StepResult output = - new StepResult( - callerCtx.workflowId(), - callerCtx.functionId(), - functionName, - null, - null, - null, - toSaveSer.serialization()) - .withOutput(toSaveSer.serializedValue()); - StepsDAO.recordStepResultTxn(ctx, output, startTime, System.currentTimeMillis()); - } - - return value; - - } finally { - lockConditionPair.lock.unlock(); - notificationService.unregisterNotificationCondition(payload); - } + return null; + // DBOSSerializer serializer = ctx.serializer(); + // var startTime = System.currentTimeMillis(); + // String functionName = "DBOS.getEvent"; + + // if (callerCtx != null) { + // StepResult recordedOutput; + // try (Connection conn = ctx.getConnection()) { + // recordedOutput = + // StepsDAO.checkStepExecutionTxn( + // conn, ctx.schema(), callerCtx.workflowId(), callerCtx.functionId(), + // functionName); + // } + + // if (recordedOutput != null) { + // logger.debug("Replaying getEvent, id: {}, key: {}", callerCtx.functionId(), key); + // if (recordedOutput.output() != null) { + // return SerializationUtil.deserializeValue( + // recordedOutput.output(), recordedOutput.serialization(), serializer); + // } else { + // throw new RuntimeException("No output recorded in the last getEvent"); + // } + // } else { + // logger.debug("Running getEvent, id: {}, key: {}", callerCtx.functionId(), key); + // } + // } + + // String payload = targetUuid + "::" + key; + // NotificationListenerService.LockConditionPair lockConditionPair = + // notificationService.getOrCreateNotificationCondition(payload); + + // lockConditionPair.lock.lock(); + // try { + // Object value = null; + // final String sql = + // """ + // SELECT value, serialization FROM "%s".workflow_events WHERE workflow_uuid = ? AND key + // = ? + // """ + // .formatted(ctx.schema()); + + // double actualTimeout = + // Objects.requireNonNull(timeout, "getEvent timeout cannot be null").toMillis(); + // var targetTime = System.currentTimeMillis() + actualTimeout; + // var checkedDBForSleep = false; + // var hasExistingNotification = false; + + // while (true) { + // if (ctx.isClosed()) throw new IllegalStateException("SystemDatabase is closed"); + // try (Connection conn = ctx.getConnection(); + // PreparedStatement stmt = conn.prepareStatement(sql)) { + + // stmt.setString(1, targetUuid); + // stmt.setString(2, key); + + // try (ResultSet rs = stmt.executeQuery()) { + // if (rs.next()) { + // String serializedValue = rs.getString("value"); + // String serialization = rs.getString("serialization"); + // value = + // SerializationUtil.deserializeValue(serializedValue, serialization, serializer); + // hasExistingNotification = true; + // } + // } + // } + + // if (hasExistingNotification) break; + // var nowTime = System.currentTimeMillis(); + // if (nowTime > targetTime) break; + + // if (callerCtx != null && !checkedDBForSleep) { + // actualTimeout = + // StepsDAO.durableSleepDuration( + // ctx, callerCtx.workflowId(), callerCtx.timeoutFunctionId(), timeout) + // .toMillis(); + // targetTime = System.currentTimeMillis() + actualTimeout; + // checkedDBForSleep = true; + // if (nowTime > targetTime) break; + // } + + // try { + // long timeoutms = (long) (targetTime - nowTime); + // logger.debug("Waiting for notification {}...", timeout); + // lockConditionPair.condition.await( + // Math.min(timeoutms, dbPollingInterval.toMillis()), TimeUnit.MILLISECONDS); + // } catch (InterruptedException e) { + // Thread.currentThread().interrupt(); + // throw new RuntimeException("Interrupted while waiting for event", e); + // } + // } + + // if (callerCtx != null) { + // var toSaveSer = SerializationUtil.serializeValue(value, null, serializer); + // StepResult output = + // new StepResult( + // callerCtx.workflowId(), + // callerCtx.functionId(), + // functionName, + // null, + // null, + // null, + // toSaveSer.serialization()) + // .withOutput(toSaveSer.serializedValue()); + // StepsDAO.recordStepResultTxn(ctx, output, startTime, System.currentTimeMillis()); + // } + + // return value; + + // } finally { + // lockConditionPair.lock.unlock(); + // notificationService.unregisterNotificationCondition(payload); + // } } static List getAllNotifications(DbContext ctx, String workflowId) diff --git a/transact/src/main/java/dev/dbos/transact/database/SignalKey.java b/transact/src/main/java/dev/dbos/transact/database/SignalKey.java new file mode 100644 index 00000000..93ac1b5a --- /dev/null +++ b/transact/src/main/java/dev/dbos/transact/database/SignalKey.java @@ -0,0 +1,39 @@ +package dev.dbos.transact.database; + +sealed interface SignalKey + permits SignalKey.Cancellation, SignalKey.Event, SignalKey.Message, SignalKey.Shutdown { + + public enum WakeReason { + MESSAGE, + EVENT, + CANCELLED, + SHUTDOWN, + TIMEOUT + } + + WakeReason wakeReason(); + + record Cancellation(String workflowId) implements SignalKey { + public WakeReason wakeReason() { + return WakeReason.CANCELLED; + } + } + + record Event(String workflowId, String key) implements SignalKey { + public WakeReason wakeReason() { + return WakeReason.EVENT; + } + } + + record Message(String workflowId, String topic) implements SignalKey { + public WakeReason wakeReason() { + return WakeReason.MESSAGE; + } + } + + record Shutdown() implements SignalKey { + public WakeReason wakeReason() { + return WakeReason.SHUTDOWN; + } + } +} diff --git a/transact/src/main/java/dev/dbos/transact/database/SignalMap.java b/transact/src/main/java/dev/dbos/transact/database/SignalMap.java new file mode 100644 index 00000000..46c3f18f --- /dev/null +++ b/transact/src/main/java/dev/dbos/transact/database/SignalMap.java @@ -0,0 +1,68 @@ +package dev.dbos.transact.database; + +import dev.dbos.transact.database.SignalKey.WakeReason; + +import java.time.Duration; +import java.util.Objects; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import java.util.concurrent.atomic.AtomicInteger; + +class SignalMap { + private static class Entry { + final CompletableFuture future = new CompletableFuture<>(); + final AtomicInteger refs = new AtomicInteger(1); + final WakeReason reason; + + public Entry(WakeReason reason) { + this.reason = Objects.requireNonNull(reason); + } + } + + private final ConcurrentHashMap map = new ConcurrentHashMap<>(); + + public Subscription subscribe(K key, WakeReason reason) { + var entry = + map.compute( + key, + (k, e) -> { + if (e != null) { + e.refs.incrementAndGet(); + return e; + } + return new Entry(reason); + }); + + var sub = + new Subscription( + () -> + map.compute(key, (k, e) -> e != null && e.refs.decrementAndGet() == 0 ? null : e)); + + entry.future.thenAccept(sub::complete); + return sub; + } + + public void signal(K key) { + var e = map.remove(key); + if (e != null) { + e.future.complete(e.reason); + } + } + + static WakeReason awaitAny(Duration timeout, Subscription... subscriptions) { + try { + return (WakeReason) + CompletableFuture.anyOf(subscriptions).get(timeout.toMillis(), TimeUnit.MILLISECONDS); + } catch (TimeoutException ignored) { + return WakeReason.TIMEOUT; + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException(e); + } catch (ExecutionException e) { + throw new RuntimeException(e); + } + } +} diff --git a/transact/src/main/java/dev/dbos/transact/database/SignalRegistry.java b/transact/src/main/java/dev/dbos/transact/database/SignalRegistry.java deleted file mode 100644 index 190f910f..00000000 --- a/transact/src/main/java/dev/dbos/transact/database/SignalRegistry.java +++ /dev/null @@ -1,121 +0,0 @@ -package dev.dbos.transact.database; - -import java.time.Duration; -import java.util.concurrent.CompletableFuture; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.ExecutionException; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.TimeoutException; -import java.util.concurrent.atomic.AtomicInteger; - -sealed interface SignalKey - permits SignalKey.Cancellation, SignalKey.Event, SignalKey.Message, SignalKey.Shutdown { - - WakeReason wakeReason(); - - record Cancellation(String workflowId) implements SignalKey { - public WakeReason wakeReason() { - return WakeReason.CANCELLED; - } - } - - record Event(String workflowId, String topic) implements SignalKey { - public WakeReason wakeReason() { - return WakeReason.EVENT; - } - } - - record Message(String workflowId, String topic) implements SignalKey { - public WakeReason wakeReason() { - return WakeReason.MESSAGE; - } - } - - record Shutdown() implements SignalKey { - public WakeReason wakeReason() { - return WakeReason.SHUTDOWN; - } - } -} - -enum WakeReason { - MESSAGE, - EVENT, - CANCELLED, - SHUTDOWN, - TIMEOUT -} - -class SignalRegistry { - - private static class Entry { - final CompletableFuture future = new CompletableFuture<>(); - final AtomicInteger refs = new AtomicInteger(1); - } - - static class Subscription extends CompletableFuture implements AutoCloseable { - private final Runnable onClose; - - Subscription(Runnable onClose) { - this.onClose = onClose; - } - - @Override - public void close() { - onClose.run(); - } - } - - private final ConcurrentHashMap map = new ConcurrentHashMap<>(); - - public Subscription subscribe(SignalKey key) { - Entry entry = - map.compute( - key, - (k, e) -> { - if (e != null) { - e.refs.incrementAndGet(); - return e; - } - return new Entry(); - }); - Subscription sub = - new Subscription( - () -> - map.compute( - key, - (k, e) -> { - if (e != null && e.refs.decrementAndGet() == 0) return null; - return e; - })); - entry.future.thenAccept(sub::complete); - return sub; - } - - public void signal(SignalKey key) { - Entry e = map.remove(key); - if (e != null) e.future.complete(key.wakeReason()); - } - - Iterable keys() { - return map.keySet(); - } - - static Subscription never() { - return new Subscription(() -> {}); - } - - static WakeReason awaitAny(Duration timeout, Subscription... subscriptions) { - try { - return (WakeReason) - CompletableFuture.anyOf(subscriptions).get(timeout.toMillis(), TimeUnit.MILLISECONDS); - } catch (TimeoutException ignored) { - return WakeReason.TIMEOUT; - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - throw new RuntimeException(e); - } catch (ExecutionException e) { - throw new RuntimeException(e); - } - } -} diff --git a/transact/src/main/java/dev/dbos/transact/database/Subscription.java b/transact/src/main/java/dev/dbos/transact/database/Subscription.java new file mode 100644 index 00000000..1db3d157 --- /dev/null +++ b/transact/src/main/java/dev/dbos/transact/database/Subscription.java @@ -0,0 +1,18 @@ +package dev.dbos.transact.database; + +import dev.dbos.transact.database.SignalKey.WakeReason; + +import java.util.concurrent.CompletableFuture; + +class Subscription extends CompletableFuture implements AutoCloseable { + private final Runnable onClose; + + Subscription(Runnable onClose) { + this.onClose = onClose; + } + + @Override + public void close() { + onClose.run(); + } +} diff --git a/transact/src/main/java/dev/dbos/transact/database/SystemDatabase.java b/transact/src/main/java/dev/dbos/transact/database/SystemDatabase.java index 5c1b8805..96c35727 100644 --- a/transact/src/main/java/dev/dbos/transact/database/SystemDatabase.java +++ b/transact/src/main/java/dev/dbos/transact/database/SystemDatabase.java @@ -35,6 +35,18 @@ public class SystemDatabase implements AutoCloseable { + public interface NotifcationRegistry { + Subscription subscribe(SignalKey.Message key); + + Subscription subscribe(SignalKey.Event key); + } + + public interface NotificationSource extends NotifcationRegistry { + void start(); + + void close(); + } + private static final Logger logger = LoggerFactory.getLogger(SystemDatabase.class); public static String sanitizeSchema(String schema) { @@ -45,7 +57,7 @@ public static String sanitizeSchema(String schema) { private final boolean created; private final AtomicBoolean closed = new AtomicBoolean(false); - private final NotificationService notificationService; + private final NotificationSource notificationSource; private Duration dbPollingInterval = Duration.ofSeconds(1); private static void validatePostgresDataSource(DataSource dataSource) { @@ -72,7 +84,8 @@ private SystemDatabase( this.ctx = new DbContext(dataSource, schema, serializer, this.closed::get); this.created = created; - notificationService = new NotificationService(dataSource); + // TODO: NotificationPollingService + notificationSource = new NotificationListenerSource(dataSource); } public SystemDatabase(String url, String user, String password, String schema) { @@ -141,14 +154,14 @@ public static HikariDataSource createDataSource(String url, String user, String @Override public void close() { closed.set(true); - notificationService.stop(); + notificationSource.close(); if (created && ctx.dataSource() instanceof HikariDataSource hikariDataSource) { hikariDataSource.close(); } } public void start() { - notificationService.start(); + notificationSource.start(); } void speedUpPollingForTest() { @@ -385,7 +398,7 @@ public Object recv( () -> NotificationsDAO.recv( ctx, - notificationService, + notificationSource, dbPollingInterval, workflowId, stepId, @@ -414,7 +427,7 @@ public Object getEvent( return dbRetry( () -> NotificationsDAO.getEvent( - ctx, notificationService, dbPollingInterval, targetId, key, timeout, callerCtx)); + ctx, notificationSource, dbPollingInterval, targetId, key, timeout, callerCtx)); } public void sleep(String workflowId, int functionId, Duration duration) { diff --git a/transact/src/test/java/dev/dbos/transact/database/SignalRegistryTest.java b/transact/src/test/java/dev/dbos/transact/database/SignalMapTest.java similarity index 61% rename from transact/src/test/java/dev/dbos/transact/database/SignalRegistryTest.java rename to transact/src/test/java/dev/dbos/transact/database/SignalMapTest.java index 10d5e865..d72988bd 100644 --- a/transact/src/test/java/dev/dbos/transact/database/SignalRegistryTest.java +++ b/transact/src/test/java/dev/dbos/transact/database/SignalMapTest.java @@ -7,6 +7,8 @@ import static org.junit.jupiter.api.Assertions.assertTimeoutPreemptively; import static org.junit.jupiter.api.Assertions.assertTrue; +import dev.dbos.transact.database.SignalKey.WakeReason; + import java.time.Duration; import java.util.concurrent.CompletableFuture; import java.util.concurrent.TimeUnit; @@ -15,9 +17,9 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.function.Executable; -class SignalRegistryTest { +class SignalMapTest { - SignalRegistry registry; + SignalMap map; static final SignalKey KEY = new SignalKey.Cancellation("wf-1"); static final SignalKey KEY_A = new SignalKey.Cancellation("wf-a"); @@ -26,7 +28,11 @@ class SignalRegistryTest { @BeforeEach void setup() { - registry = new SignalRegistry(); + map = new SignalMap<>(); + } + + private static Subscription never() { + return new Subscription(() -> {}); } // --- SignalKey structural equality --- @@ -45,9 +51,9 @@ void testSignalKeyStructuralEquality() { @Test void testBasicSubscribeAndSignal() throws Exception { - var f = registry.subscribe(KEY); + var f = map.subscribe(KEY, KEY.wakeReason()); assertFalse(f.isDone()); - registry.signal(KEY); + map.signal(KEY); assertTrue(f.isDone()); assertFalse(f.isCompletedExceptionally()); assertEquals(WakeReason.CANCELLED, f.get()); @@ -55,15 +61,15 @@ void testBasicSubscribeAndSignal() throws Exception { @Test void testMultipleListenersOnSameKey() throws Exception { - var f1 = registry.subscribe(KEY); - var f2 = registry.subscribe(KEY); - var f3 = registry.subscribe(KEY); + var f1 = map.subscribe(KEY, KEY.wakeReason()); + var f2 = map.subscribe(KEY, KEY.wakeReason()); + var f3 = map.subscribe(KEY, KEY.wakeReason()); assertFalse(f1.isDone()); assertFalse(f2.isDone()); assertFalse(f3.isDone()); - registry.signal(KEY); + map.signal(KEY); assertEquals(WakeReason.CANCELLED, f1.get()); assertEquals(WakeReason.CANCELLED, f2.get()); @@ -72,13 +78,13 @@ void testMultipleListenersOnSameKey() throws Exception { @Test void testMultipleSubscriptionsInAnyOf() throws Exception { - var f1 = registry.subscribe(KEY_A); - var f2 = registry.subscribe(KEY_B); + var f1 = map.subscribe(KEY_A, KEY_A.wakeReason()); + var f2 = map.subscribe(KEY_B, KEY_B.wakeReason()); var anyOf = CompletableFuture.anyOf(f1, f2); assertFalse(anyOf.isDone()); - registry.signal(KEY_B); + map.signal(KEY_B); assertEquals(WakeReason.CANCELLED, (WakeReason) anyOf.get(1, TimeUnit.SECONDS)); assertTrue(f2.isDone()); @@ -87,10 +93,10 @@ void testMultipleSubscriptionsInAnyOf() throws Exception { @Test void testSignalOnlyWakesMatchingKey() { - var f1 = registry.subscribe(KEY_A); - var f2 = registry.subscribe(KEY_B); + var f1 = map.subscribe(KEY_A, KEY_A.wakeReason()); + var f2 = map.subscribe(KEY_B, KEY_B.wakeReason()); - registry.signal(KEY_A); + map.signal(KEY_A); assertTrue(f1.isDone()); assertFalse(f2.isDone()); @@ -101,10 +107,10 @@ void testDifferentKeyTypesWithSameFieldsDoNotCollide() { var eventKey = new SignalKey.Event("wf-1", "wf-1"); var cancellationKey = new SignalKey.Cancellation("wf-1"); - var f1 = registry.subscribe(eventKey); - var f2 = registry.subscribe(cancellationKey); + var f1 = map.subscribe(eventKey, eventKey.wakeReason()); + var f2 = map.subscribe(cancellationKey, cancellationKey.wakeReason()); - registry.signal(cancellationKey); + map.signal(cancellationKey); assertTrue(f2.isDone()); assertFalse(f1.isDone()); @@ -112,36 +118,36 @@ void testDifferentKeyTypesWithSameFieldsDoNotCollide() { @Test void testSignalBeforeSubscribeDoesNotWake() { - registry.signal(KEY); + map.signal(KEY); - var f = registry.subscribe(KEY); + var f = map.subscribe(KEY, KEY.wakeReason()); assertFalse(f.isDone()); - registry.signal(KEY); + map.signal(KEY); assertTrue(f.isDone()); } @Test void testSignalIsOneShot() { - var f1 = registry.subscribe(KEY); - registry.signal(KEY); + var f1 = map.subscribe(KEY, KEY.wakeReason()); + map.signal(KEY); assertTrue(f1.isDone()); - var f2 = registry.subscribe(KEY); + var f2 = map.subscribe(KEY, KEY.wakeReason()); assertFalse(f2.isDone()); } @Test void testWakeReasonByKeyType() throws Exception { - var msgSub = registry.subscribe(new SignalKey.Message("wf-1", "topic")); - var eventSub = registry.subscribe(new SignalKey.Event("wf-1", "topic")); - var cancelSub = registry.subscribe(new SignalKey.Cancellation("wf-1")); - var shutdownSub = registry.subscribe(new SignalKey.Shutdown()); + var msgSub = map.subscribe(new SignalKey.Message("wf-1", "topic"), WakeReason.MESSAGE); + var eventSub = map.subscribe(new SignalKey.Event("wf-1", "topic"), WakeReason.EVENT); + var cancelSub = map.subscribe(new SignalKey.Cancellation("wf-1"), WakeReason.CANCELLED); + var shutdownSub = map.subscribe(new SignalKey.Shutdown(), WakeReason.SHUTDOWN); - registry.signal(new SignalKey.Message("wf-1", "topic")); - registry.signal(new SignalKey.Event("wf-1", "topic")); - registry.signal(new SignalKey.Cancellation("wf-1")); - registry.signal(new SignalKey.Shutdown()); + map.signal(new SignalKey.Message("wf-1", "topic")); + map.signal(new SignalKey.Event("wf-1", "topic")); + map.signal(new SignalKey.Cancellation("wf-1")); + map.signal(new SignalKey.Shutdown()); assertEquals(WakeReason.MESSAGE, msgSub.get()); assertEquals(WakeReason.EVENT, eventSub.get()); @@ -153,20 +159,20 @@ void testWakeReasonByKeyType() throws Exception { @Test void testCloseOneSubscriberDoesNotOrphanOthers() throws Exception { - var sub1 = registry.subscribe(KEY); - var sub2 = registry.subscribe(KEY); + var sub1 = map.subscribe(KEY, KEY.wakeReason()); + var sub2 = map.subscribe(KEY, KEY.wakeReason()); sub1.close(); - registry.signal(KEY); + map.signal(KEY); assertEquals(WakeReason.CANCELLED, sub2.get()); } @Test void testClosePreventsFutureFromBeingSignalled() throws Exception { - var sub = registry.subscribe(KEY); + var sub = map.subscribe(KEY, KEY.wakeReason()); sub.close(); - registry.signal(KEY); + map.signal(KEY); boolean completed = sub.orTimeout(100, TimeUnit.MILLISECONDS).handle((v, ex) -> ex == null).get(); @@ -175,60 +181,18 @@ void testClosePreventsFutureFromBeingSignalled() throws Exception { @Test void testNeverFutureNeverCompletes() throws Exception { - var f = SignalRegistry.never(); + var f = never(); assertFalse(f.isDone()); boolean completed = f.orTimeout(100, TimeUnit.MILLISECONDS).handle((v, ex) -> ex == null).get(); assertFalse(completed); } - // --- keys() --- - - @Test - void testKeysReflectsActiveSubscriptions() { - registry.subscribe(KEY_A); - registry.subscribe(KEY_B); - - var keys = registry.keys(); - assertTrue(iterableContains(keys, KEY_A)); - assertTrue(iterableContains(keys, KEY_B)); - } - - @Test - void testKeysExcludesSignalledKey() { - registry.subscribe(KEY_A); - registry.subscribe(KEY_B); - registry.signal(KEY_A); - - var keys = registry.keys(); - assertFalse(iterableContains(keys, KEY_A)); - assertTrue(iterableContains(keys, KEY_B)); - } - - @Test - void testKeysEmptyWhenNoSubscribers() { - assertFalse(registry.keys().iterator().hasNext()); - } - - @Test - void testKeysExcludesKeyAfterLastSubscriberCloses() { - var sub = registry.subscribe(KEY); - sub.close(); - assertFalse(registry.keys().iterator().hasNext()); - } - - private static boolean iterableContains(Iterable keys, SignalKey target) { - for (SignalKey k : keys) { - if (k.equals(target)) return true; - } - return false; - } - // --- Threading --- @Test void testSubscribeBeforeSignalFromAnotherThread() throws Exception { - var f = registry.subscribe(FOO); + var f = map.subscribe(FOO, FOO.wakeReason()); CompletableFuture.runAsync( () -> { @@ -237,7 +201,7 @@ void testSubscribeBeforeSignalFromAnotherThread() throws Exception { } catch (InterruptedException e) { Thread.currentThread().interrupt(); } - registry.signal(FOO); + map.signal(FOO); }); assertTimeoutPreemptively(Duration.ofSeconds(1), (Executable) f::get); @@ -249,7 +213,7 @@ void testSignalFromMainAfterBackgroundSubscribes() throws Exception { CompletableFuture.runAsync( () -> { - var f = registry.subscribe(FOO); + var f = map.subscribe(FOO, FOO.wakeReason()); try { backgroundDone.complete(f.get(500, TimeUnit.MILLISECONDS)); } catch (Exception e) { @@ -258,7 +222,7 @@ void testSignalFromMainAfterBackgroundSubscribes() throws Exception { }); Thread.sleep(100); - registry.signal(FOO); + map.signal(FOO); assertTimeoutPreemptively(Duration.ofSeconds(1), (Executable) backgroundDone::get); assertEquals(WakeReason.CANCELLED, backgroundDone.get()); @@ -272,7 +236,7 @@ void testMultipleSubscribersInSeparateThreads() throws Exception { CompletableFuture.runAsync( () -> { try { - done1.complete(registry.subscribe(FOO).get(500, TimeUnit.MILLISECONDS)); + done1.complete(map.subscribe(FOO, FOO.wakeReason()).get(500, TimeUnit.MILLISECONDS)); } catch (Exception e) { done1.completeExceptionally(e); } @@ -280,14 +244,14 @@ void testMultipleSubscribersInSeparateThreads() throws Exception { CompletableFuture.runAsync( () -> { try { - done2.complete(registry.subscribe(FOO).get(500, TimeUnit.MILLISECONDS)); + done2.complete(map.subscribe(FOO, FOO.wakeReason()).get(500, TimeUnit.MILLISECONDS)); } catch (Exception e) { done2.completeExceptionally(e); } }); Thread.sleep(100); - registry.signal(FOO); + map.signal(FOO); assertTimeoutPreemptively( Duration.ofSeconds(1), @@ -303,9 +267,9 @@ void testConcurrentSignalAndSubscribe() throws Exception { Duration.ofSeconds(5), () -> { for (int i = 0; i < 1000; i++) { - var r = new SignalRegistry(); - var sub = r.subscribe(KEY); - CompletableFuture.runAsync(() -> r.signal(KEY)); + var m = new SignalMap(); + var sub = m.subscribe(KEY, KEY.wakeReason()); + CompletableFuture.runAsync(() -> m.signal(KEY)); sub.get(1, TimeUnit.SECONDS); } }); @@ -315,52 +279,52 @@ void testConcurrentSignalAndSubscribe() throws Exception { @Test void testAwaitAny_notifyFires() throws Exception { - var onNotify = registry.subscribe(new SignalKey.Event("wf-1", "topic")); - var onCancelled = registry.subscribe(new SignalKey.Cancellation("wf-1")); - var onShutdown = SignalRegistry.never(); + var onNotify = map.subscribe(new SignalKey.Event("wf-1", "topic"), WakeReason.EVENT); + var onCancelled = map.subscribe(new SignalKey.Cancellation("wf-1"), WakeReason.CANCELLED); + var onShutdown = never(); - registry.signal(new SignalKey.Event("wf-1", "topic")); + map.signal(new SignalKey.Event("wf-1", "topic")); assertEquals( WakeReason.EVENT, - SignalRegistry.awaitAny(Duration.ofSeconds(1), onNotify, onCancelled, onShutdown)); + SignalMap.awaitAny(Duration.ofSeconds(1), onNotify, onCancelled, onShutdown)); } @Test void testAwaitAny_cancelledFires() throws Exception { - var onNotify = registry.subscribe(new SignalKey.Event("wf-1", "topic")); - var onCancelled = registry.subscribe(new SignalKey.Cancellation("wf-1")); - var onShutdown = SignalRegistry.never(); + var onNotify = map.subscribe(new SignalKey.Event("wf-1", "topic"), WakeReason.EVENT); + var onCancelled = map.subscribe(new SignalKey.Cancellation("wf-1"), WakeReason.CANCELLED); + var onShutdown = never(); - registry.signal(new SignalKey.Cancellation("wf-1")); + map.signal(new SignalKey.Cancellation("wf-1")); assertEquals( WakeReason.CANCELLED, - SignalRegistry.awaitAny(Duration.ofSeconds(1), onNotify, onCancelled, onShutdown)); + SignalMap.awaitAny(Duration.ofSeconds(1), onNotify, onCancelled, onShutdown)); } @Test void testAwaitAny_shutdownFires() throws Exception { - var onNotify = registry.subscribe(new SignalKey.Event("wf-1", "topic")); - var onCancelled = SignalRegistry.never(); - var onShutdown = registry.subscribe(new SignalKey.Shutdown()); + var onNotify = map.subscribe(new SignalKey.Event("wf-1", "topic"), WakeReason.EVENT); + var onCancelled = never(); + var onShutdown = map.subscribe(new SignalKey.Shutdown(), WakeReason.SHUTDOWN); - registry.signal(new SignalKey.Shutdown()); + map.signal(new SignalKey.Shutdown()); assertEquals( WakeReason.SHUTDOWN, - SignalRegistry.awaitAny(Duration.ofSeconds(1), onNotify, onCancelled, onShutdown)); + SignalMap.awaitAny(Duration.ofSeconds(1), onNotify, onCancelled, onShutdown)); } @Test void testAwaitAny_timeout() throws Exception { - var onNotify = registry.subscribe(new SignalKey.Event("wf-1", "topic")); - var onCancelled = SignalRegistry.never(); - var onShutdown = SignalRegistry.never(); + var onNotify = map.subscribe(new SignalKey.Event("wf-1", "topic"), WakeReason.EVENT); + var onCancelled = never(); + var onShutdown = never(); assertEquals( WakeReason.TIMEOUT, - SignalRegistry.awaitAny(Duration.ofMillis(50), onNotify, onCancelled, onShutdown)); + SignalMap.awaitAny(Duration.ofMillis(50), onNotify, onCancelled, onShutdown)); } // --- anyOf determination via isDone (Option A) --- @@ -370,11 +334,11 @@ void testCheckIsDone_notifyFires() throws Exception { var notifyKey = new SignalKey.Event("wf-1", "topic"); var cancelKey = new SignalKey.Cancellation("wf-1"); - var onNotify = registry.subscribe(notifyKey); - var onCancelled = registry.subscribe(cancelKey); - var onShutdown = SignalRegistry.never(); + var onNotify = map.subscribe(notifyKey, notifyKey.wakeReason()); + var onCancelled = map.subscribe(cancelKey, cancelKey.wakeReason()); + var onShutdown = never(); - registry.signal(notifyKey); + map.signal(notifyKey); try { CompletableFuture.anyOf(onNotify, onCancelled, onShutdown).get(1, TimeUnit.SECONDS); @@ -390,11 +354,11 @@ void testCheckIsDone_cancelledFires() throws Exception { var notifyKey = new SignalKey.Event("wf-1", "topic"); var cancelKey = new SignalKey.Cancellation("wf-1"); - var onNotify = registry.subscribe(notifyKey); - var onCancelled = registry.subscribe(cancelKey); - var onShutdown = SignalRegistry.never(); + var onNotify = map.subscribe(notifyKey, notifyKey.wakeReason()); + var onCancelled = map.subscribe(cancelKey, cancelKey.wakeReason()); + var onShutdown = never(); - registry.signal(cancelKey); + map.signal(cancelKey); try { CompletableFuture.anyOf(onNotify, onCancelled, onShutdown).get(1, TimeUnit.SECONDS); @@ -407,11 +371,11 @@ void testCheckIsDone_cancelledFires() throws Exception { @Test void testCheckIsDone_shutdownFires() throws Exception { - var onNotify = registry.subscribe(new SignalKey.Event("wf-1", "topic")); - var onCancelled = SignalRegistry.never(); - var onShutdown = registry.subscribe(new SignalKey.Shutdown()); + var onNotify = map.subscribe(new SignalKey.Event("wf-1", "topic"), WakeReason.EVENT); + var onCancelled = never(); + var onShutdown = map.subscribe(new SignalKey.Shutdown(), WakeReason.SHUTDOWN); - registry.signal(new SignalKey.Shutdown()); + map.signal(new SignalKey.Shutdown()); try { CompletableFuture.anyOf(onNotify, onCancelled, onShutdown).get(1, TimeUnit.SECONDS); @@ -424,9 +388,9 @@ void testCheckIsDone_shutdownFires() throws Exception { @Test void testCheckIsDone_timeout() throws Exception { - var onNotify = registry.subscribe(new SignalKey.Event("wf-1", "topic")); - var onCancelled = SignalRegistry.never(); - var onShutdown = SignalRegistry.never(); + var onNotify = map.subscribe(new SignalKey.Event("wf-1", "topic"), WakeReason.EVENT); + var onCancelled = never(); + var onShutdown = never(); try { CompletableFuture.anyOf(onNotify, onCancelled, onShutdown).get(50, TimeUnit.MILLISECONDS); @@ -438,18 +402,17 @@ void testCheckIsDone_timeout() throws Exception { } // --- anyOf determination via tagged dispatch (Option B) --- - // WakeReason is now embedded in each Subscription — no thenApply needed. @Test void testTaggedDispatch_notifyFires() throws Exception { var notifyKey = new SignalKey.Event("wf-1", "topic"); var cancelKey = new SignalKey.Cancellation("wf-1"); - var onNotify = registry.subscribe(notifyKey); - var onCancelled = registry.subscribe(cancelKey); - var onShutdown = SignalRegistry.never(); + var onNotify = map.subscribe(notifyKey, notifyKey.wakeReason()); + var onCancelled = map.subscribe(cancelKey, cancelKey.wakeReason()); + var onShutdown = never(); - registry.signal(notifyKey); + map.signal(notifyKey); var reason = (WakeReason) @@ -463,11 +426,11 @@ void testTaggedDispatch_cancelledFires() throws Exception { var notifyKey = new SignalKey.Event("wf-1", "topic"); var cancelKey = new SignalKey.Cancellation("wf-1"); - var onNotify = registry.subscribe(notifyKey); - var onCancelled = registry.subscribe(cancelKey); - var onShutdown = SignalRegistry.never(); + var onNotify = map.subscribe(notifyKey, notifyKey.wakeReason()); + var onCancelled = map.subscribe(cancelKey, cancelKey.wakeReason()); + var onShutdown = never(); - registry.signal(cancelKey); + map.signal(cancelKey); var reason = (WakeReason) @@ -478,11 +441,11 @@ void testTaggedDispatch_cancelledFires() throws Exception { @Test void testTaggedDispatch_shutdownFires() throws Exception { - var onNotify = registry.subscribe(new SignalKey.Event("wf-1", "topic")); - var onCancelled = SignalRegistry.never(); - var onShutdown = registry.subscribe(new SignalKey.Shutdown()); + var onNotify = map.subscribe(new SignalKey.Event("wf-1", "topic"), WakeReason.EVENT); + var onCancelled = never(); + var onShutdown = map.subscribe(new SignalKey.Shutdown(), WakeReason.SHUTDOWN); - registry.signal(new SignalKey.Shutdown()); + map.signal(new SignalKey.Shutdown()); var reason = (WakeReason) @@ -493,9 +456,9 @@ void testTaggedDispatch_shutdownFires() throws Exception { @Test void testTaggedDispatch_timeout() throws Exception { - var onNotify = registry.subscribe(new SignalKey.Event("wf-1", "topic")); - var onCancelled = SignalRegistry.never(); - var onShutdown = SignalRegistry.never(); + var onNotify = map.subscribe(new SignalKey.Event("wf-1", "topic"), WakeReason.EVENT); + var onCancelled = never(); + var onShutdown = never(); WakeReason reason = null; try { From f2ac1fbc8aa0c02171c1e05a413457b0cb21c7c9 Mon Sep 17 00:00:00 2001 From: Harry Pierson Date: Thu, 14 May 2026 22:13:29 -0700 Subject: [PATCH 10/27] recv + naming --- .../dev/dbos/transact/database/DbContext.java | 8 +- .../transact/database/NotificationsDAO.java | 251 ++++++++---------- .../dev/dbos/transact/database/StepsDAO.java | 45 +++- .../dbos/transact/database/StreamsDAO.java | 5 +- .../transact/database/SystemDatabase.java | 14 +- .../dbos/transact/database/WorkflowDAO.java | 4 +- .../dbos/transact/execution/DBOSExecutor.java | 6 +- .../workflow/internal/StepResult.java | 13 + .../transact/database/ImportExportTest.java | 4 +- 9 files changed, 179 insertions(+), 171 deletions(-) diff --git a/transact/src/main/java/dev/dbos/transact/database/DbContext.java b/transact/src/main/java/dev/dbos/transact/database/DbContext.java index 7bc024b2..fd4ad7ef 100644 --- a/transact/src/main/java/dev/dbos/transact/database/DbContext.java +++ b/transact/src/main/java/dev/dbos/transact/database/DbContext.java @@ -11,11 +11,13 @@ record DbContext( DataSource dataSource, String schema, DBOSSerializer serializer, BooleanSupplier closed) { - Connection getConnection() throws SQLException { + public Connection getConnection() throws SQLException { return dataSource.getConnection(); } - boolean isClosed() { - return closed.getAsBoolean(); + public void checkClosed() { + if (closed.getAsBoolean()) { + throw new IllegalStateException("Database is closed"); + } } } diff --git a/transact/src/main/java/dev/dbos/transact/database/NotificationsDAO.java b/transact/src/main/java/dev/dbos/transact/database/NotificationsDAO.java index 3d53d84e..3477ec8d 100644 --- a/transact/src/main/java/dev/dbos/transact/database/NotificationsDAO.java +++ b/transact/src/main/java/dev/dbos/transact/database/NotificationsDAO.java @@ -10,11 +10,13 @@ import java.sql.Connection; import java.sql.PreparedStatement; +import java.sql.ResultSet; import java.sql.SQLException; import java.time.Duration; import java.time.Instant; import java.util.ArrayList; import java.util.List; +import java.util.Objects; import java.util.UUID; import org.slf4j.Logger; @@ -47,7 +49,7 @@ static void send( try { StepResult recordedOutput = - StepsDAO.checkStepExecutionTxn(conn, ctx.schema(), workflowId, stepId, functionName); + StepsDAO.checkStepResult(conn, ctx.schema(), workflowId, stepId, functionName); if (recordedOutput != null) { logger.debug( @@ -92,7 +94,7 @@ ON CONFLICT (message_uuid) DO NOTHING } var output = new StepResult(workflowId, stepId, functionName, null, null, null, null); - StepsDAO.recordStepResultTxn( + StepsDAO.recordStepResult( conn, ctx.schema(), output, startTime, System.currentTimeMillis()); conn.commit(); @@ -148,155 +150,132 @@ ON CONFLICT (message_uuid) DO NOTHING static Object recv( DbContext ctx, - NotifcationRegistry notifcationRegistry, - Duration dbPollingInterval, String workflowId, int stepId, - int timeoutFunctionId, + Duration timeout, + int timeoutStepId, String topic, - Duration timeout) + Duration dbPollingInterval, + NotifcationRegistry notifcationRegistry) throws SQLException { - return null; - // DBOSSerializer serializer = ctx.serializer(); - // var startTime = System.currentTimeMillis(); - // String functionName = "DBOS.recv"; - // String finalTopic = (topic != null) ? topic : Constants.DBOS_NULL_TOPIC; - - // StepResult recordedOutput; - // try (Connection c = ctx.getConnection()) { - // recordedOutput = - // StepsDAO.checkStepExecutionTxn(c, ctx.schema(), workflowId, stepId, functionName); - // } - - // if (recordedOutput != null) { - // logger.debug("Replaying recv, id: {}, topic: {}", stepId, finalTopic); - // if (recordedOutput.output() != null) { - // return SerializationUtil.deserializeValue( - // recordedOutput.output(), recordedOutput.serialization(), serializer); - // } else { - // throw new RuntimeException("No output recorded in the last recv"); - // } - // } else { - // logger.debug("Running recv, wfid {}, id: {}, topic: {}", workflowId, stepId, finalTopic); - // } + if (Objects.requireNonNull(workflowId).isEmpty()) { + throw new IllegalArgumentException("workflowId must not be empty"); + } - // String payload = workflowId + "::" + finalTopic; - // var lockPair = new NotificationListenerService.LockConditionPair(); + var stepName = "DBOS.recv"; + topic = Objects.requireNonNullElse(topic, Constants.DBOS_NULL_TOPIC); - // double actualTimeout = timeout.toMillis(); - // var targetTime = System.currentTimeMillis() + actualTimeout; - // var checkedDBForSleep = false; + var recordedResult = StepsDAO.checkStepResult(ctx, workflowId, stepId, stepName); + if (recordedResult != null) { + logger.debug( + "Replaying recv, workflowId: {}, stepId: {}, topic: {}", workflowId, stepId, topic); + if (recordedResult.output() != null) { + return recordedResult.toResult(ctx.serializer()); + } + logger.debug( + "Running recv, workflowId: {}, stepId: {}, topic: {}", workflowId, stepId, topic); + } - // try { - // lockPair.lock.lock(); - // boolean success = notificationService.registerNotificationCondition(payload, lockPair); - // if (!success) { - // throw new DBOSWorkflowExecutionConflictException(workflowId); - // } + var startTime = System.currentTimeMillis(); + var messageKey = new SignalKey.Message(workflowId, topic); + dbPollingInterval = Objects.requireNonNullElse(dbPollingInterval, Duration.ofSeconds(1)); - // while (true) { - // if (ctx.isClosed()) throw new IllegalStateException("SystemDatabase is closed"); - // boolean hasExistingNotification; - // try (Connection conn = ctx.getConnection()) { - // final String sql = - // """ - // SELECT topic FROM "%s".notifications - // WHERE destination_uuid = ? AND topic = ? AND consumed = FALSE - // """ - // .formatted(ctx.schema()); - - // try (PreparedStatement stmt = conn.prepareStatement(sql)) { - // stmt.setString(1, workflowId); - // stmt.setString(2, finalTopic); - // try (ResultSet rs = stmt.executeQuery()) { - // hasExistingNotification = rs.next(); - // } - // } - // } + try (var messageSignal = notifcationRegistry.subscribe(messageKey)) { + while (true) { + ctx.checkClosed(); + var sql = + """ + SELECT topic FROM "%s".notifications + WHERE destination_uuid = ? AND topic = ? AND consumed = FALSE + """ + .formatted(ctx.schema()); + try (var conn = ctx.getConnection(); + var stmt = conn.prepareStatement(sql)) { + stmt.setString(1, workflowId); + stmt.setString(2, topic); + try (var rs = stmt.executeQuery()) { + if (rs.next()) { + // query for results + break; + } + } + } - // if (hasExistingNotification) break; + // check cancelled - // var nowTime = System.currentTimeMillis(); + var sleepDuration = StepsDAO.durableSleepDuration(ctx, workflowId, timeoutStepId, timeout); + if (sleepDuration.isNegative() || sleepDuration.isZero()) { + var output = SerializationUtil.serializeValue(null, null, ctx.serializer()); + var stepResult = StepResult.ofOutput(workflowId, stepId, stepName, output); + StepsDAO.recordStepResult(ctx, stepResult, startTime); + return null; + } - // if (!checkedDBForSleep) { - // actualTimeout = - // StepsDAO.durableSleepDuration(ctx, workflowId, timeoutFunctionId, - // timeout).toMillis(); - // checkedDBForSleep = true; - // targetTime = nowTime + actualTimeout; - // } - // if (nowTime >= targetTime) break; - // long timeoutMs = (long) Math.min(targetTime - nowTime, dbPollingInterval.toMillis()); + var loopDuration = + dbPollingInterval.compareTo(sleepDuration) <= 0 ? dbPollingInterval : sleepDuration; - // try { - // lockPair.condition.await(timeoutMs, TimeUnit.MILLISECONDS); - // } catch (InterruptedException e) { - // Thread.currentThread().interrupt(); - // throw new RuntimeException("Interrupted while waiting for message", e); - // } - // } - // } finally { - // lockPair.lock.unlock(); - // notificationService.unregisterNotificationCondition(payload); - // } - - // try (Connection conn = ctx.getConnection()) { - // conn.setAutoCommit(false); - - // try { - // final String sql = - // """ - // UPDATE "%1$s".notifications - // SET consumed = TRUE - // WHERE destination_uuid = ? - // AND topic = ? - // AND consumed = FALSE - // AND message_uuid = ( - // SELECT message_uuid FROM "%1$s".notifications - // WHERE destination_uuid = ? - // AND topic = ? - // AND consumed = FALSE - // ORDER BY created_at_epoch_ms ASC - // LIMIT 1 - // ) - // RETURNING message, serialization - // """ - // .formatted(ctx.schema()); - - // String serializedMessage = null; - // String serialization = null; - // try (PreparedStatement stmt = conn.prepareStatement(sql)) { - // stmt.setString(1, workflowId); - // stmt.setString(2, finalTopic); - // stmt.setString(3, workflowId); - // stmt.setString(4, finalTopic); + SignalMap.awaitAny(loopDuration, messageSignal); + } + } - // try (ResultSet rs = stmt.executeQuery()) { - // if (rs.next()) { - // serializedMessage = rs.getString("message"); - // serialization = rs.getString("serialization"); - // } - // } - // } + ctx.checkClosed(); + var sql = + """ + UPDATE "%1$s".notifications + SET consumed = TRUE + WHERE destination_uuid = ? + AND topic = ? + AND consumed = FALSE + AND message_uuid = ( + SELECT message_uuid FROM "%1$s".notifications + WHERE destination_uuid = ? + AND topic = ? + AND consumed = FALSE + ORDER BY created_at_epoch_ms ASC + LIMIT 1 + ) + RETURNING message, serialization + """ + .formatted(ctx.schema()); - // var recvdMessage = - // SerializationUtil.deserializeValue(serializedMessage, serialization, serializer); + try (var conn = ctx.getConnection()) { + conn.setAutoCommit(false); + try { + String serializedMessage = null; + String serialization = null; + try (PreparedStatement stmt = conn.prepareStatement(sql)) { + stmt.setString(1, workflowId); + stmt.setString(2, topic); + stmt.setString(3, workflowId); + stmt.setString(4, topic); + + // Note, if there are two executors running the same workflow waiting on the same recv, + // only the first one will return a row here. The second one get a null message but then + // throw a WorkflowExecutionConflictException when it records the step result. + try (ResultSet rs = stmt.executeQuery()) { + if (rs.next()) { + serializedMessage = rs.getString("message"); + serialization = rs.getString("serialization"); + } + } + } - // StepResult output = - // new StepResult( - // workflowId, stepId, functionName, serializedMessage, null, null, serialization); - // StepsDAO.recordStepResultTxn( - // conn, ctx.schema(), output, startTime, System.currentTimeMillis()); + var deserializedMessage = + SerializationUtil.deserializeValue(serializedMessage, serialization, ctx.serializer()); - // conn.commit(); - // return recvdMessage; + var output = + new StepResult( + workflowId, stepId, stepName, serializedMessage, null, null, serialization); + StepsDAO.recordStepResult(conn, ctx.schema(), output, startTime); - // } catch (Exception e) { - // conn.rollback(); - // throw e; - // } - // } + conn.commit(); + return deserializedMessage; + } catch (Exception e) { + conn.rollback(); + throw e; + } + } } private static void setEvent( @@ -366,8 +345,7 @@ static void setEvent( try { if (asStep) { var recordedOutput = - StepsDAO.checkStepExecutionTxn( - conn, ctx.schema(), workflowId, functionId, functionName); + StepsDAO.checkStepResult(conn, ctx.schema(), workflowId, functionId, functionName); if (recordedOutput != null) { logger.debug( "Replaying setEvent, workflow: {}, step: {}, key: {}", workflowId, functionId, key); @@ -391,8 +369,7 @@ static void setEvent( if (asStep) { StepResult output = new StepResult(workflowId, functionId, functionName, null, null, null, null); - StepsDAO.recordStepResultTxn( - conn, ctx.schema(), output, startTime, System.currentTimeMillis()); + StepsDAO.recordStepResult(conn, ctx.schema(), output, startTime); } conn.commit(); diff --git a/transact/src/main/java/dev/dbos/transact/database/StepsDAO.java b/transact/src/main/java/dev/dbos/transact/database/StepsDAO.java index ab82dfe4..602969f4 100644 --- a/transact/src/main/java/dev/dbos/transact/database/StepsDAO.java +++ b/transact/src/main/java/dev/dbos/transact/database/StepsDAO.java @@ -25,16 +25,27 @@ private StepsDAO() {} private static final Logger logger = LoggerFactory.getLogger(StepsDAO.class); - static void recordStepResultTxn( + static void recordStepResult(DbContext ctx, StepResult result, long startTimeEpochMs) + throws SQLException { + recordStepResult(ctx, result, startTimeEpochMs, System.currentTimeMillis()); + } + + static void recordStepResult( DbContext ctx, StepResult result, long startTimeEpochMs, long endTimeEpochMs) throws SQLException { try (var conn = ctx.getConnection()) { - recordStepResultTxn(conn, ctx.schema(), result, startTimeEpochMs, endTimeEpochMs); + recordStepResult(conn, ctx.schema(), result, startTimeEpochMs, endTimeEpochMs); } DebugTriggers.debugTriggerPoint(DebugTriggers.DEBUG_TRIGGER_STEP_COMMIT); } - static void recordStepResultTxn( + static void recordStepResult( + Connection conn, String schema, StepResult result, long startTimeEpochMs) + throws SQLException { + recordStepResult(conn, schema, result, startTimeEpochMs, System.currentTimeMillis()); + } + + static void recordStepResult( Connection conn, String schema, StepResult result, Long startTimeEpochMs, Long endTimeEpochMs) throws SQLException { @@ -96,9 +107,16 @@ static void recordStepResultTxn( } } - static StepResult checkStepExecutionTxn( + static StepResult checkStepResult( + DbContext ctx, String workflowId, int functionId, String functionName) throws SQLException { + try (var conn = ctx.getConnection()) { + return checkStepResult(conn, ctx.schema(), workflowId, functionId, functionName); + } + } + + static StepResult checkStepResult( Connection conn, String schema, String workflowId, int functionId, String functionName) - throws SQLException, DBOSWorkflowCancelledException, DBOSUnexpectedStepException { + throws SQLException { Objects.requireNonNull(schema); final String sql = @@ -187,11 +205,11 @@ static List listWorkflowSteps( StringBuilder sqlBuilder = new StringBuilder( """ - SELECT function_id, function_name, output, error, child_workflow_id, started_at_epoch_ms, completed_at_epoch_ms, serialization - FROM "%s".operation_outputs - WHERE workflow_uuid = ? - ORDER BY function_id - """ + SELECT function_id, function_name, output, error, child_workflow_id, started_at_epoch_ms, completed_at_epoch_ms, serialization + FROM "%s".operation_outputs + WHERE workflow_uuid = ? + ORDER BY function_id + """ .formatted(schema)); if (limit != null) { @@ -303,7 +321,7 @@ static boolean patch(DbContext ctx, String workflowId, int functionId, String pa var checkpointName = getCheckpointName(conn, ctx.schema(), workflowId, functionId); if (checkpointName == null) { var output = new StepResult(workflowId, functionId, patchName, null, null, null, null); - recordStepResultTxn(conn, ctx.schema(), output, System.currentTimeMillis(), null); + recordStepResult(conn, ctx.schema(), output, System.currentTimeMillis(), null); return true; } else { return patchName.equals(checkpointName); @@ -331,8 +349,7 @@ static Duration durableSleepDuration( StepResult recordedOutput; try (var conn = ctx.getConnection()) { - recordedOutput = - checkStepExecutionTxn(conn, ctx.schema(), workflowUuid, functionId, functionName); + recordedOutput = checkStepResult(conn, ctx.schema(), workflowUuid, functionId, functionName); } long endTime; @@ -368,7 +385,7 @@ static Duration durableSleepDuration( null, null, serializedValue.serialization()); - recordStepResultTxn(ctx, output, startTime, (long) endTime); + recordStepResult(ctx, output, startTime, (long) endTime); } catch (DBOSWorkflowExecutionConflictException e) { logger.error("Error recording sleep", e); } diff --git a/transact/src/main/java/dev/dbos/transact/database/StreamsDAO.java b/transact/src/main/java/dev/dbos/transact/database/StreamsDAO.java index b0881766..1d428884 100644 --- a/transact/src/main/java/dev/dbos/transact/database/StreamsDAO.java +++ b/transact/src/main/java/dev/dbos/transact/database/StreamsDAO.java @@ -44,8 +44,7 @@ static void writeStreamFromWorkflow( try { StepResult recordedOutput = - StepsDAO.checkStepExecutionTxn( - conn, ctx.schema(), workflowId, functionId, functionName); + StepsDAO.checkStepResult(conn, ctx.schema(), workflowId, functionId, functionName); if (recordedOutput != null) { logger.debug("Replaying writeStream, id: {}, key: {}", functionId, key); @@ -58,7 +57,7 @@ static void writeStreamFromWorkflow( insertStream(conn, ctx.schema(), workflowId, functionId, key, value, serializationFormat); var output = new StepResult(workflowId, functionId, functionName, null, null, null, null); - StepsDAO.recordStepResultTxn( + StepsDAO.recordStepResult( conn, ctx.schema(), output, startTime, System.currentTimeMillis()); conn.commit(); diff --git a/transact/src/main/java/dev/dbos/transact/database/SystemDatabase.java b/transact/src/main/java/dev/dbos/transact/database/SystemDatabase.java index 96c35727..c8246b0e 100644 --- a/transact/src/main/java/dev/dbos/transact/database/SystemDatabase.java +++ b/transact/src/main/java/dev/dbos/transact/database/SystemDatabase.java @@ -322,20 +322,20 @@ public List getQueuePartitions(String queueName) { return dbRetry(() -> QueuesDAO.getQueuePartitions(ctx, queueName)); } - public StepResult checkStepExecutionTxn(String workflowId, int functionId, String functionName) { + public StepResult checkStepResult(String workflowId, int functionId, String functionName) { return dbRetry( () -> { try (Connection connection = ctx.getConnection()) { - return StepsDAO.checkStepExecutionTxn( + return StepsDAO.checkStepResult( connection, ctx.schema(), workflowId, functionId, functionName); } }); } - public void recordStepResultTxn(StepResult result, long startTime) { + public void recordStepResult(StepResult result, long startTime) { var et = System.currentTimeMillis(); - dbRetry(() -> StepsDAO.recordStepResultTxn(ctx, result, startTime, et)); + dbRetry(() -> StepsDAO.recordStepResult(ctx, result, startTime, et)); } public List listWorkflowSteps( @@ -398,13 +398,13 @@ public Object recv( () -> NotificationsDAO.recv( ctx, - notificationSource, - dbPollingInterval, workflowId, stepId, + timeout, timeoutStepId, topic, - timeout)); + dbPollingInterval, + notificationSource)); } public void setEvent( diff --git a/transact/src/main/java/dev/dbos/transact/database/WorkflowDAO.java b/transact/src/main/java/dev/dbos/transact/database/WorkflowDAO.java index d9fe41ad..9d8ac061 100644 --- a/transact/src/main/java/dev/dbos/transact/database/WorkflowDAO.java +++ b/transact/src/main/java/dev/dbos/transact/database/WorkflowDAO.java @@ -839,7 +839,7 @@ static Result awaitWorkflowResult( .formatted(ctx.schema()); while (true) { - if (ctx.isClosed()) throw new IllegalStateException("SystemDatabase is closed"); + ctx.checkClosed(); try (Connection connection = ctx.getConnection(); PreparedStatement stmt = connection.prepareStatement(sql)) { @@ -895,7 +895,7 @@ static void recordChildWorkflow( new StepResult(parentId, functionId, functionName, null, null, null, null) .withChildWorkflowId(childId); try (var conn = ctx.getConnection()) { - StepsDAO.recordStepResultTxn(conn, ctx.schema(), result, null, null); + StepsDAO.recordStepResult(conn, ctx.schema(), result, null, null); } } diff --git a/transact/src/main/java/dev/dbos/transact/execution/DBOSExecutor.java b/transact/src/main/java/dev/dbos/transact/execution/DBOSExecutor.java index 51a42164..9caa2a28 100644 --- a/transact/src/main/java/dev/dbos/transact/execution/DBOSExecutor.java +++ b/transact/src/main/java/dev/dbos/transact/execution/DBOSExecutor.java @@ -1045,7 +1045,7 @@ private T runStepInternal( var stepId = ctx.getAndIncrementFunctionId(); logger.debug("executeStep #{} ({}) for workflow {}", stepId, options.name(), workflowId); - var prevResult = systemDatabase.checkStepExecutionTxn(workflowId, stepId, options.name()); + var prevResult = systemDatabase.checkStepResult(workflowId, stepId, options.name()); if (prevResult != null) { if (prevResult.error() != null) { var t = @@ -1122,7 +1122,7 @@ private T runStepInternal( serializedException.serializedValue(), childWorkflowId, serializedException.serialization()); - systemDatabase.recordStepResultTxn(stepResult, startTime); + systemDatabase.recordStepResult(stepResult, startTime); throw (E) exception; } else { logger.debug("executeStep #{} for workflow {} completed {}", stepId, workflowId, output); @@ -1137,7 +1137,7 @@ private T runStepInternal( null, childWorkflowId, serializedOutput.serialization()); - systemDatabase.recordStepResultTxn(stepResult, startTime); + systemDatabase.recordStepResult(stepResult, startTime); return output; } } diff --git a/transact/src/main/java/dev/dbos/transact/workflow/internal/StepResult.java b/transact/src/main/java/dev/dbos/transact/workflow/internal/StepResult.java index c706e912..0db440dd 100644 --- a/transact/src/main/java/dev/dbos/transact/workflow/internal/StepResult.java +++ b/transact/src/main/java/dev/dbos/transact/workflow/internal/StepResult.java @@ -2,6 +2,7 @@ import dev.dbos.transact.json.DBOSSerializer; import dev.dbos.transact.json.SerializationUtil; +import dev.dbos.transact.json.SerializationUtil.SerializedResult; public record StepResult( String workflowId, @@ -32,6 +33,18 @@ public StepResult withSerialization(String v) { return new StepResult(workflowId, stepId, stepName, output, error, childWorkflowId, v); } + public static StepResult ofOutput( + String workflowId, int stepId, String stepName, SerializedResult result) { + return new StepResult( + workflowId, stepId, stepName, result.serializedValue(), null, null, result.serialization()); + } + + public static StepResult ofError( + String workflowId, int stepId, String stepName, SerializedResult result) { + return new StepResult( + workflowId, stepId, stepName, null, result.serializedValue(), null, result.serialization()); + } + @SuppressWarnings("unchecked") public R toResult(DBOSSerializer serializer) throws E { if (error != null) { diff --git a/transact/src/test/java/dev/dbos/transact/database/ImportExportTest.java b/transact/src/test/java/dev/dbos/transact/database/ImportExportTest.java index 8cc4a218..4423e20a 100644 --- a/transact/src/test/java/dev/dbos/transact/database/ImportExportTest.java +++ b/transact/src/test/java/dev/dbos/transact/database/ImportExportTest.java @@ -59,8 +59,8 @@ private void createWorkflow(String wfId) throws SQLException { sysdb.recordWorkflowOutput(wfId, null); long now = System.currentTimeMillis(); - sysdb.recordStepResultTxn(new StepResult(wfId, 0, "step0"), now - 2000); - sysdb.recordStepResultTxn(new StepResult(wfId, 1, "step1"), now - 1000); + sysdb.recordStepResult(new StepResult(wfId, 0, "step0"), now - 2000); + sysdb.recordStepResult(new StepResult(wfId, 1, "step1"), now - 1000); // asStep=false: writes to workflow_events + workflow_events_history, not operation_outputs sysdb.setEvent(wfId, 0, "event-key-1", "event-val-1", false, null); From c18474057b7e99814a0c559200511b1a4233aa1a Mon Sep 17 00:00:00 2001 From: Harry Pierson Date: Fri, 15 May 2026 10:17:11 -0700 Subject: [PATCH 11/27] crdb migration test --- gradle/libs.versions.toml | 1 + transact/build.gradle.kts | 1 + .../transact/database/SystemDatabase.java | 16 +++ .../transact/migrations/MigrationManager.java | 68 +++++----- .../migrations/CockroachMigrationTest.java | 118 ++++++++++++++++++ .../dev/dbos/transact/utils/PgContainer.java | 2 +- 6 files changed, 177 insertions(+), 29 deletions(-) create mode 100644 transact/src/test/java/dev/dbos/transact/migrations/CockroachMigrationTest.java diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index f9f4b167..154d52bc 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -70,6 +70,7 @@ spring-boot4-dependencies = { module = "org.springframework.boot:spring-boot-dep spring-boot4-test = { module = "org.springframework.boot:spring-boot-test", version.ref = "spring-boot4" } sqlite-jdbc = { module = "org.xerial:sqlite-jdbc", version.ref = "sqlite-jdbc" } system-stubs-jupiter = { module = "uk.org.webcompere:system-stubs-jupiter", version.ref = "system-stubs" } +testcontainers-cockroachdb = { module = "org.testcontainers:testcontainers-cockroachdb", version.ref = "testcontainers" } testcontainers-postgresql = { module = "org.testcontainers:testcontainers-postgresql", version.ref = "testcontainers" } [bundles] diff --git a/transact/build.gradle.kts b/transact/build.gradle.kts index 27bc382c..4eb6ac7e 100644 --- a/transact/build.gradle.kts +++ b/transact/build.gradle.kts @@ -48,6 +48,7 @@ dependencies { testImplementation(libs.rest.assured) testImplementation(libs.kryo) testImplementation(libs.maven.artifact) + testImplementation(libs.testcontainers.cockroachdb) testImplementation(libs.testcontainers.postgresql) } diff --git a/transact/src/main/java/dev/dbos/transact/database/SystemDatabase.java b/transact/src/main/java/dev/dbos/transact/database/SystemDatabase.java index c8246b0e..fd6ca822 100644 --- a/transact/src/main/java/dev/dbos/transact/database/SystemDatabase.java +++ b/transact/src/main/java/dev/dbos/transact/database/SystemDatabase.java @@ -151,6 +151,22 @@ public static HikariDataSource createDataSource(String url, String user, String return new HikariDataSource(config); } + public static boolean isCockroach(DataSource dataSource) throws SQLException { + try (var conn = dataSource.getConnection()) { + return isCockroach(conn); + } + } + + public static boolean isCockroach(Connection conn) throws SQLException { + try (var stmt = conn.createStatement(); + var rs = stmt.executeQuery("SELECT version()")) { + if (rs.next()) { + return rs.getString(1).toLowerCase().contains("cockroachdb"); + } + } + return false; + } + @Override public void close() { closed.set(true); diff --git a/transact/src/main/java/dev/dbos/transact/migrations/MigrationManager.java b/transact/src/main/java/dev/dbos/transact/migrations/MigrationManager.java index 98786cb2..1fe7d686 100644 --- a/transact/src/main/java/dev/dbos/transact/migrations/MigrationManager.java +++ b/transact/src/main/java/dev/dbos/transact/migrations/MigrationManager.java @@ -56,15 +56,7 @@ private static void runMigrations(DataSource ds, String schema, boolean useListe try (var conn = ds.getConnection()) { - var isCockroach = false; - try (var stmt = conn.createStatement(); - var rs = stmt.executeQuery("SELECT version()")) { - if (rs.next()) { - String version = rs.getString(1).toLowerCase(); - isCockroach = version.contains("cockroachdb"); - } - } - + var isCockroach = SystemDatabase.isCockroach(conn); if (isCockroach) { useListenNotify = false; } @@ -188,6 +180,19 @@ public static int getCurrentSysDbVersion(Connection conn, String schema) { return 0; } + private static boolean notificationsPrimaryKeyExists(Connection conn, String schema) + throws SQLException { + var sql = + "SELECT 1 FROM information_schema.table_constraints" + + " WHERE table_schema = ? AND table_name = 'notifications' AND constraint_type = 'PRIMARY KEY'"; + try (var stmt = conn.prepareStatement(sql)) { + stmt.setString(1, schema); + try (var rs = stmt.executeQuery()) { + return rs.next(); + } + } + } + static void runDbosMigrations(Connection conn, String schema, List migrations) { Objects.requireNonNull(schema, "schema must not be null"); var lastApplied = getCurrentSysDbVersion(conn, schema); @@ -199,10 +204,27 @@ static void runDbosMigrations(Connection conn, String schema, List migra } logger.info("Applying DBOS system database schema migration {}", migrationIndex); - try (var stmt = conn.createStatement()) { - stmt.execute(migrations.get(i)); - } catch (SQLException e) { - throw new RuntimeException("Failed to run migration %d".formatted(migrationIndex), e); + + // Migration 10 adds a primary key to notifications. Skip the DDL if one already exists + // (guard for installs that were created before the primary key was added to migration 1). + boolean skipMigration = false; + if (migrationIndex == 10) { + try { + skipMigration = notificationsPrimaryKeyExists(conn, schema); + } catch (SQLException e) { + throw new RuntimeException("Failed to check notifications primary key", e); + } + if (skipMigration) { + logger.info("Migration 10 skipped, primary key already exists"); + } + } + + if (!skipMigration) { + try (var stmt = conn.createStatement()) { + stmt.execute(migrations.get(i)); + } catch (SQLException e) { + throw new RuntimeException("Failed to run migration %d".formatted(migrationIndex), e); + } } try { @@ -274,8 +296,8 @@ static String migration1(boolean useListenNotify) { output TEXT, error TEXT, executor_id TEXT, - created_at BIGINT NOT NULL DEFAULT (EXTRACT(epoch FROM now()) * 1000::numeric)::bigint, - updated_at BIGINT NOT NULL DEFAULT (EXTRACT(epoch FROM now()) * 1000::numeric)::bigint, + created_at BIGINT NOT NULL DEFAULT (EXTRACT(epoch FROM now()) * 1000.0)::bigint, + updated_at BIGINT NOT NULL DEFAULT (EXTRACT(epoch FROM now()) * 1000.0)::bigint, application_version TEXT, application_id TEXT, class_name VARCHAR(255) DEFAULT NULL, @@ -315,7 +337,7 @@ message_uuid TEXT NOT NULL DEFAULT gen_random_uuid() PRIMARY KEY, -- Built-in fu destination_uuid TEXT NOT NULL, topic TEXT, message TEXT NOT NULL, - created_at_epoch_ms BIGINT NOT NULL DEFAULT (EXTRACT(epoch FROM now()) * 1000::numeric)::bigint, + created_at_epoch_ms BIGINT NOT NULL DEFAULT (EXTRACT(epoch FROM now()) * 1000.0)::bigint, FOREIGN KEY (destination_uuid) REFERENCES "%1$s".workflow_status(workflow_uuid) ON UPDATE CASCADE ON DELETE CASCADE ); @@ -421,7 +443,7 @@ FOREIGN KEY (workflow_uuid) REFERENCES "%1$s".workflow_status(workflow_uuid) static final String MIGRATION_7 = """ - ALTER TABLE "%1$s"."workflow_status" ADD COLUMN "owner_xid" VARCHAR(40) DEFAULT NULL + ALTER TABLE "%1$s"."workflow_status" ADD COLUMN "owner_xid" TEXT DEFAULT NULL """; static final String MIGRATION_8 = @@ -445,17 +467,7 @@ FOREIGN KEY (workflow_uuid) REFERENCES "%1$s".workflow_status(workflow_uuid) static final String MIGRATION_10 = """ - DO $$ - BEGIN - IF NOT EXISTS ( - SELECT 1 FROM information_schema.table_constraints - WHERE table_schema = '%1$s' - AND table_name = 'notifications' - AND constraint_type = 'PRIMARY KEY' - ) THEN - ALTER TABLE "%1$s".notifications ADD PRIMARY KEY (message_uuid); - END IF; - END $$; + ALTER TABLE "%1$s".notifications ADD PRIMARY KEY (message_uuid); """; static final String MIGRATION_11 = diff --git a/transact/src/test/java/dev/dbos/transact/migrations/CockroachMigrationTest.java b/transact/src/test/java/dev/dbos/transact/migrations/CockroachMigrationTest.java new file mode 100644 index 00000000..69a61739 --- /dev/null +++ b/transact/src/test/java/dev/dbos/transact/migrations/CockroachMigrationTest.java @@ -0,0 +1,118 @@ +package dev.dbos.transact.migrations; + +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import dev.dbos.transact.config.DBOSConfig; +import dev.dbos.transact.database.SystemDatabase; + +import java.sql.Connection; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.parallel.Isolated; +import org.testcontainers.cockroachdb.CockroachContainer; +import org.testcontainers.postgresql.PostgreSQLContainer; + +/** + * Tests that LISTEN/NOTIFY functions and triggers are created on PG and omitted on CRDB. Each test + * spins up its own container so these tests are always tied to the specified DB type regardless of + * how PgContainer is configured for the rest of the test suite. + */ +@Isolated +class CockroachMigrationTest { + + @Test + void testPg_isCockroachReturnsFalse() throws Exception { + try (var pg = new PostgreSQLContainer("postgres:latest")) { + pg.start(); + try (var ds = + SystemDatabase.createDataSource(pg.getJdbcUrl(), pg.getUsername(), pg.getPassword())) { + assertFalse(SystemDatabase.isCockroach(ds)); + try (Connection conn = ds.getConnection()) { + assertFalse(SystemDatabase.isCockroach(conn)); + } + } + } + } + + @Test + void testCrdb_isCockroachReturnsTrue() throws Exception { + try (var crdb = new CockroachContainer("cockroachdb/cockroach:latest")) { + crdb.start(); + try (var ds = + SystemDatabase.createDataSource( + crdb.getJdbcUrl(), crdb.getUsername(), crdb.getPassword())) { + assertTrue(SystemDatabase.isCockroach(ds)); + try (Connection conn = ds.getConnection()) { + assertTrue(SystemDatabase.isCockroach(conn)); + } + } + } + } + + @Test + void testPg_notifyFunctionsAndTriggersPresent() throws Exception { + try (var pg = new PostgreSQLContainer("postgres:latest")) { + pg.start(); + try (var ds = + SystemDatabase.createDataSource(pg.getJdbcUrl(), pg.getUsername(), pg.getPassword())) { + var config = DBOSConfig.defaults("migration-notify-test").withDataSource(ds); + MigrationManager.runMigrations(config); + + try (Connection conn = ds.getConnection()) { + var meta = conn.getMetaData(); + MigrationManagerTest.assertFunctionExists(meta, "notifications_function"); + MigrationManagerTest.assertFunctionExists(meta, "workflow_events_function"); + MigrationManagerTest.assertTriggerExists(conn, "dbos_notifications_trigger"); + MigrationManagerTest.assertTriggerExists(conn, "dbos_workflow_events_trigger"); + } + } + } + } + + @Test + void testPg_noListenNotify_notifyFunctionsAndTriggersAbsent() throws Exception { + try (var pg = new PostgreSQLContainer("postgres:latest")) { + pg.start(); + try (var ds = + SystemDatabase.createDataSource(pg.getJdbcUrl(), pg.getUsername(), pg.getPassword())) { + var config = + DBOSConfig.defaults("migration-notify-test") + .withDataSource(ds) + .withUseListenNotify(false); + MigrationManager.runMigrations(config); + + try (Connection conn = ds.getConnection()) { + var meta = conn.getMetaData(); + MigrationManagerTest.assertFunctionAbsent(conn, "notifications_function"); + MigrationManagerTest.assertFunctionAbsent(conn, "workflow_events_function"); + MigrationManagerTest.assertTriggerAbsent(conn, "dbos_notifications_trigger"); + MigrationManagerTest.assertTriggerAbsent(conn, "dbos_workflow_events_trigger"); + MigrationManagerTest.assertFunctionExists(meta, "enqueue_workflow"); + MigrationManagerTest.assertFunctionExists(meta, "send_message"); + } + } + } + } + + @Test + void testCrdb_notifyFunctionsAndTriggersAbsent() throws Exception { + try (var crdb = new CockroachContainer("cockroachdb/cockroach:latest")) { + crdb.start(); + try (var ds = + SystemDatabase.createDataSource( + crdb.getJdbcUrl(), crdb.getUsername(), crdb.getPassword())) { + // useListenNotify=true intentionally — MigrationManager must override it to false for CRDB + var config = DBOSConfig.defaults("migration-notify-test").withDataSource(ds); + MigrationManager.runMigrations(config); + + try (Connection conn = ds.getConnection()) { + MigrationManagerTest.assertFunctionAbsent(conn, "notifications_function"); + MigrationManagerTest.assertFunctionAbsent(conn, "workflow_events_function"); + MigrationManagerTest.assertTriggerAbsent(conn, "dbos_notifications_trigger"); + MigrationManagerTest.assertTriggerAbsent(conn, "dbos_workflow_events_trigger"); + } + } + } + } +} diff --git a/transact/src/test/java/dev/dbos/transact/utils/PgContainer.java b/transact/src/test/java/dev/dbos/transact/utils/PgContainer.java index 2c1431dd..a466d222 100644 --- a/transact/src/test/java/dev/dbos/transact/utils/PgContainer.java +++ b/transact/src/test/java/dev/dbos/transact/utils/PgContainer.java @@ -38,7 +38,7 @@ static PostgreSQLContainer acquire() { PERMITS.acquire(); var container = POOL.poll(); if (container == null) { - container = new PostgreSQLContainer("postgres:18"); + container = new PostgreSQLContainer("postgres:latest"); container.start(); } return container; From 7e4cd09d88017d60439aa7d2829a98d1bca91a2d Mon Sep 17 00:00:00 2001 From: Harry Pierson Date: Fri, 15 May 2026 10:34:13 -0700 Subject: [PATCH 12/27] reorg sysdb types --- .../dev/dbos/transact/database/DbContext.java | 2 +- .../database/NotificationListenerSource.java | 3 + .../dbos/transact/database/Subscription.java | 18 ----- .../transact/database/SystemDatabase.java | 14 +++- .../{ => dao}/ApplicationVersionDAO.java | 14 ++-- .../database/{ => dao}/ExternalStateDAO.java | 12 ++-- .../database/{ => dao}/NotificationsDAO.java | 20 +++--- .../database/{ => dao}/QueuesDAO.java | 12 ++-- .../database/{ => dao}/SchedulesDAO.java | 23 ++++--- .../transact/database/{ => dao}/StepsDAO.java | 19 +++--- .../database/{ => dao}/StreamsDAO.java | 15 +++-- .../database/{ => dao}/WorkflowDAO.java | 66 +++++++++++-------- .../database/{ => signal}/SignalKey.java | 4 +- .../database/{ => signal}/SignalMap.java | 8 +-- .../database/signal/Subscription.java | 18 +++++ .../dbos/transact/database/SignalMapTest.java | 5 +- 16 files changed, 148 insertions(+), 105 deletions(-) delete mode 100644 transact/src/main/java/dev/dbos/transact/database/Subscription.java rename transact/src/main/java/dev/dbos/transact/database/{ => dao}/ApplicationVersionDAO.java (84%) rename transact/src/main/java/dev/dbos/transact/database/{ => dao}/ExternalStateDAO.java (90%) rename transact/src/main/java/dev/dbos/transact/database/{ => dao}/NotificationsDAO.java (97%) rename transact/src/main/java/dev/dbos/transact/database/{ => dao}/QueuesDAO.java (95%) rename transact/src/main/java/dev/dbos/transact/database/{ => dao}/SchedulesDAO.java (91%) rename transact/src/main/java/dev/dbos/transact/database/{ => dao}/StepsDAO.java (95%) rename transact/src/main/java/dev/dbos/transact/database/{ => dao}/StreamsDAO.java (92%) rename transact/src/main/java/dev/dbos/transact/database/{ => dao}/WorkflowDAO.java (96%) rename transact/src/main/java/dev/dbos/transact/database/{ => signal}/SignalKey.java (91%) rename transact/src/main/java/dev/dbos/transact/database/{ => signal}/SignalMap.java (88%) create mode 100644 transact/src/main/java/dev/dbos/transact/database/signal/Subscription.java diff --git a/transact/src/main/java/dev/dbos/transact/database/DbContext.java b/transact/src/main/java/dev/dbos/transact/database/DbContext.java index fd4ad7ef..d92dbfce 100644 --- a/transact/src/main/java/dev/dbos/transact/database/DbContext.java +++ b/transact/src/main/java/dev/dbos/transact/database/DbContext.java @@ -8,7 +8,7 @@ import javax.sql.DataSource; -record DbContext( +public record DbContext( DataSource dataSource, String schema, DBOSSerializer serializer, BooleanSupplier closed) { public Connection getConnection() throws SQLException { diff --git a/transact/src/main/java/dev/dbos/transact/database/NotificationListenerSource.java b/transact/src/main/java/dev/dbos/transact/database/NotificationListenerSource.java index eb2a652c..87e2f8c9 100644 --- a/transact/src/main/java/dev/dbos/transact/database/NotificationListenerSource.java +++ b/transact/src/main/java/dev/dbos/transact/database/NotificationListenerSource.java @@ -1,6 +1,9 @@ package dev.dbos.transact.database; import dev.dbos.transact.database.SystemDatabase.NotificationSource; +import dev.dbos.transact.database.signal.SignalKey; +import dev.dbos.transact.database.signal.SignalMap; +import dev.dbos.transact.database.signal.Subscription; import java.sql.Connection; import java.sql.SQLException; diff --git a/transact/src/main/java/dev/dbos/transact/database/Subscription.java b/transact/src/main/java/dev/dbos/transact/database/Subscription.java deleted file mode 100644 index 1db3d157..00000000 --- a/transact/src/main/java/dev/dbos/transact/database/Subscription.java +++ /dev/null @@ -1,18 +0,0 @@ -package dev.dbos.transact.database; - -import dev.dbos.transact.database.SignalKey.WakeReason; - -import java.util.concurrent.CompletableFuture; - -class Subscription extends CompletableFuture implements AutoCloseable { - private final Runnable onClose; - - Subscription(Runnable onClose) { - this.onClose = onClose; - } - - @Override - public void close() { - onClose.run(); - } -} diff --git a/transact/src/main/java/dev/dbos/transact/database/SystemDatabase.java b/transact/src/main/java/dev/dbos/transact/database/SystemDatabase.java index fd6ca822..a82943d1 100644 --- a/transact/src/main/java/dev/dbos/transact/database/SystemDatabase.java +++ b/transact/src/main/java/dev/dbos/transact/database/SystemDatabase.java @@ -2,6 +2,16 @@ import dev.dbos.transact.Constants; import dev.dbos.transact.config.DBOSConfig; +import dev.dbos.transact.database.dao.ApplicationVersionDAO; +import dev.dbos.transact.database.dao.ExternalStateDAO; +import dev.dbos.transact.database.dao.NotificationsDAO; +import dev.dbos.transact.database.dao.QueuesDAO; +import dev.dbos.transact.database.dao.SchedulesDAO; +import dev.dbos.transact.database.dao.StepsDAO; +import dev.dbos.transact.database.dao.StreamsDAO; +import dev.dbos.transact.database.dao.WorkflowDAO; +import dev.dbos.transact.database.signal.SignalKey; +import dev.dbos.transact.database.signal.Subscription; import dev.dbos.transact.exceptions.*; import dev.dbos.transact.json.DBOSSerializer; import dev.dbos.transact.workflow.ExportedWorkflow; @@ -252,11 +262,11 @@ private T dbRetry(SqlSupplier supplier) { } } - static Instant toInstant(Long epochMs) { + public static Instant toInstant(Long epochMs) { return epochMs != null ? Instant.ofEpochMilli(epochMs) : null; } - static Duration toDuration(Long ms) { + public static Duration toDuration(Long ms) { return ms != null ? Duration.ofMillis(ms) : null; } diff --git a/transact/src/main/java/dev/dbos/transact/database/ApplicationVersionDAO.java b/transact/src/main/java/dev/dbos/transact/database/dao/ApplicationVersionDAO.java similarity index 84% rename from transact/src/main/java/dev/dbos/transact/database/ApplicationVersionDAO.java rename to transact/src/main/java/dev/dbos/transact/database/dao/ApplicationVersionDAO.java index 4ea44d52..fd7c6fbf 100644 --- a/transact/src/main/java/dev/dbos/transact/database/ApplicationVersionDAO.java +++ b/transact/src/main/java/dev/dbos/transact/database/dao/ApplicationVersionDAO.java @@ -1,5 +1,6 @@ -package dev.dbos.transact.database; +package dev.dbos.transact.database.dao; +import dev.dbos.transact.database.DbContext; import dev.dbos.transact.workflow.VersionInfo; import java.sql.SQLException; @@ -8,11 +9,12 @@ import java.util.List; import java.util.UUID; -class ApplicationVersionDAO { +public class ApplicationVersionDAO { private ApplicationVersionDAO() {} - static void createApplicationVersion(DbContext ctx, String versionName) throws SQLException { + public static void createApplicationVersion(DbContext ctx, String versionName) + throws SQLException { String sql = """ INSERT INTO "%s".application_versions (version_id, version_name) @@ -28,7 +30,7 @@ ON CONFLICT (version_name) DO NOTHING } } - static void updateApplicationVersionTimestamp( + public static void updateApplicationVersionTimestamp( DbContext ctx, String versionName, Instant newTimestamp) throws SQLException { String sql = """ @@ -45,7 +47,7 @@ static void updateApplicationVersionTimestamp( } } - static List listApplicationVersions(DbContext ctx) throws SQLException { + public static List listApplicationVersions(DbContext ctx) throws SQLException { String sql = """ SELECT version_id, version_name, version_timestamp, created_at @@ -69,7 +71,7 @@ static List listApplicationVersions(DbContext ctx) throws SQLExcept return results; } - static VersionInfo getLatestApplicationVersion(DbContext ctx) throws SQLException { + public static VersionInfo getLatestApplicationVersion(DbContext ctx) throws SQLException { String sql = """ SELECT version_id, version_name, version_timestamp, created_at diff --git a/transact/src/main/java/dev/dbos/transact/database/ExternalStateDAO.java b/transact/src/main/java/dev/dbos/transact/database/dao/ExternalStateDAO.java similarity index 90% rename from transact/src/main/java/dev/dbos/transact/database/ExternalStateDAO.java rename to transact/src/main/java/dev/dbos/transact/database/dao/ExternalStateDAO.java index d231a31f..476e5070 100644 --- a/transact/src/main/java/dev/dbos/transact/database/ExternalStateDAO.java +++ b/transact/src/main/java/dev/dbos/transact/database/dao/ExternalStateDAO.java @@ -1,4 +1,7 @@ -package dev.dbos.transact.database; +package dev.dbos.transact.database.dao; + +import dev.dbos.transact.database.DbContext; +import dev.dbos.transact.database.ExternalState; import java.math.BigDecimal; import java.math.BigInteger; @@ -6,11 +9,11 @@ import java.util.Objects; import java.util.Optional; -class ExternalStateDAO { +public class ExternalStateDAO { private ExternalStateDAO() {} - static Optional getExternalState( + public static Optional getExternalState( DbContext ctx, String service, String workflowName, String key) throws SQLException { final String sql = """ @@ -38,7 +41,8 @@ static Optional getExternalState( } } - static ExternalState upsertExternalState(DbContext ctx, ExternalState state) throws SQLException { + public static ExternalState upsertExternalState(DbContext ctx, ExternalState state) + throws SQLException { final var sql = """ INSERT INTO "%s".event_dispatch_kv ( diff --git a/transact/src/main/java/dev/dbos/transact/database/NotificationsDAO.java b/transact/src/main/java/dev/dbos/transact/database/dao/NotificationsDAO.java similarity index 97% rename from transact/src/main/java/dev/dbos/transact/database/NotificationsDAO.java rename to transact/src/main/java/dev/dbos/transact/database/dao/NotificationsDAO.java index 3477ec8d..4de20084 100644 --- a/transact/src/main/java/dev/dbos/transact/database/NotificationsDAO.java +++ b/transact/src/main/java/dev/dbos/transact/database/dao/NotificationsDAO.java @@ -1,7 +1,11 @@ -package dev.dbos.transact.database; +package dev.dbos.transact.database.dao; import dev.dbos.transact.Constants; +import dev.dbos.transact.database.DbContext; +import dev.dbos.transact.database.GetWorkflowEventContext; import dev.dbos.transact.database.SystemDatabase.NotifcationRegistry; +import dev.dbos.transact.database.signal.SignalKey; +import dev.dbos.transact.database.signal.SignalMap; import dev.dbos.transact.exceptions.DBOSNonExistentWorkflowException; import dev.dbos.transact.json.DBOSSerializer; import dev.dbos.transact.json.SerializationUtil; @@ -22,13 +26,13 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -class NotificationsDAO { +public class NotificationsDAO { private NotificationsDAO() {} private static final Logger logger = LoggerFactory.getLogger(NotificationsDAO.class); - static void send( + public static void send( DbContext ctx, String workflowId, int stepId, @@ -110,7 +114,7 @@ ON CONFLICT (message_uuid) DO NOTHING } } - static void sendDirect( + public static void sendDirect( DbContext ctx, String destinationId, Object message, @@ -148,7 +152,7 @@ ON CONFLICT (message_uuid) DO NOTHING } } - static Object recv( + public static Object recv( DbContext ctx, String workflowId, int stepId, @@ -323,7 +327,7 @@ ON CONFLICT (workflow_uuid, key, function_id) } } - static void setEvent( + public static void setEvent( DbContext ctx, String workflowId, int functionId, @@ -382,7 +386,7 @@ static void setEvent( } } - static Object getEvent( + public static Object getEvent( DbContext ctx, NotifcationRegistry notifcationRegistry, Duration dbPollingInterval, @@ -506,7 +510,7 @@ static Object getEvent( // } } - static List getAllNotifications(DbContext ctx, String workflowId) + public static List getAllNotifications(DbContext ctx, String workflowId) throws SQLException { DBOSSerializer serializer = ctx.serializer(); var sql = diff --git a/transact/src/main/java/dev/dbos/transact/database/QueuesDAO.java b/transact/src/main/java/dev/dbos/transact/database/dao/QueuesDAO.java similarity index 95% rename from transact/src/main/java/dev/dbos/transact/database/QueuesDAO.java rename to transact/src/main/java/dev/dbos/transact/database/dao/QueuesDAO.java index 654b82d2..0c065766 100644 --- a/transact/src/main/java/dev/dbos/transact/database/QueuesDAO.java +++ b/transact/src/main/java/dev/dbos/transact/database/dao/QueuesDAO.java @@ -1,5 +1,6 @@ -package dev.dbos.transact.database; +package dev.dbos.transact.database.dao; +import dev.dbos.transact.database.DbContext; import dev.dbos.transact.workflow.Queue; import dev.dbos.transact.workflow.WorkflowState; @@ -17,13 +18,13 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -class QueuesDAO { +public class QueuesDAO { private QueuesDAO() {} private static final Logger logger = LoggerFactory.getLogger(QueuesDAO.class); - static List getAndStartQueuedWorkflows( + public static List getAndStartQueuedWorkflows( DbContext ctx, Queue queue, String executorId, String appVersion, String partitionKey) throws SQLException { @@ -244,7 +245,7 @@ THEN EXTRACT(epoch FROM NOW()) * 1000 + workflow_timeout_ms } } - static boolean clearQueueAssignment(DbContext ctx, String workflowId) throws SQLException { + public static boolean clearQueueAssignment(DbContext ctx, String workflowId) throws SQLException { final String sql = """ @@ -264,7 +265,8 @@ static boolean clearQueueAssignment(DbContext ctx, String workflowId) throws SQL } } - static List getQueuePartitions(DbContext ctx, String queueName) throws SQLException { + public static List getQueuePartitions(DbContext ctx, String queueName) + throws SQLException { final String sql = """ diff --git a/transact/src/main/java/dev/dbos/transact/database/SchedulesDAO.java b/transact/src/main/java/dev/dbos/transact/database/dao/SchedulesDAO.java similarity index 91% rename from transact/src/main/java/dev/dbos/transact/database/SchedulesDAO.java rename to transact/src/main/java/dev/dbos/transact/database/dao/SchedulesDAO.java index 3619c5f4..e846940d 100644 --- a/transact/src/main/java/dev/dbos/transact/database/SchedulesDAO.java +++ b/transact/src/main/java/dev/dbos/transact/database/dao/SchedulesDAO.java @@ -1,5 +1,6 @@ -package dev.dbos.transact.database; +package dev.dbos.transact.database.dao; +import dev.dbos.transact.database.DbContext; import dev.dbos.transact.execution.SchedulerService; import dev.dbos.transact.json.DBOSSerializer; import dev.dbos.transact.json.SerializationUtil; @@ -20,11 +21,11 @@ import java.util.StringJoiner; import java.util.UUID; -class SchedulesDAO { +public class SchedulesDAO { private SchedulesDAO() {} - static void createSchedule(DbContext ctx, WorkflowSchedule schedule) throws SQLException { + public static void createSchedule(DbContext ctx, WorkflowSchedule schedule) throws SQLException { try (Connection conn = ctx.getConnection()) { createSchedule(conn, ctx.schema(), ctx.serializer(), schedule); } @@ -80,7 +81,7 @@ static void createSchedule( } } - static List listSchedules( + public static List listSchedules( DbContext ctx, List statuses, List workflowNames, @@ -148,7 +149,8 @@ static List listSchedules( } } - static Optional getSchedule(DbContext ctx, String name) throws SQLException { + public static Optional getSchedule(DbContext ctx, String name) + throws SQLException { DBOSSerializer serializer = ctx.serializer(); String sql = """ @@ -172,11 +174,11 @@ static Optional getSchedule(DbContext ctx, String name) throws } } - static void pauseSchedule(DbContext ctx, String name) throws SQLException { + public static void pauseSchedule(DbContext ctx, String name) throws SQLException { setScheduleStatus(ctx, name, ScheduleStatus.PAUSED); } - static void resumeSchedule(DbContext ctx, String name) throws SQLException { + public static void resumeSchedule(DbContext ctx, String name) throws SQLException { setScheduleStatus(ctx, name, ScheduleStatus.ACTIVE); } @@ -196,7 +198,7 @@ private static void setScheduleStatus(DbContext ctx, String name, ScheduleStatus } } - static void updateScheduleLastFiredAt(DbContext ctx, String name, Instant lastFiredAt) + public static void updateScheduleLastFiredAt(DbContext ctx, String name, Instant lastFiredAt) throws SQLException { String sql = """ @@ -212,7 +214,7 @@ static void updateScheduleLastFiredAt(DbContext ctx, String name, Instant lastFi } } - static void deleteSchedule(DbContext ctx, String name) throws SQLException { + public static void deleteSchedule(DbContext ctx, String name) throws SQLException { try (var conn = ctx.getConnection()) { deleteSchedule(conn, ctx.schema(), name); } @@ -231,7 +233,8 @@ static void deleteSchedule(Connection conn, String schema, String name) throws S } } - static void applySchedules(DbContext ctx, List schedules) throws SQLException { + public static void applySchedules(DbContext ctx, List schedules) + throws SQLException { try (var conn = ctx.getConnection()) { conn.setAutoCommit(false); try { diff --git a/transact/src/main/java/dev/dbos/transact/database/StepsDAO.java b/transact/src/main/java/dev/dbos/transact/database/dao/StepsDAO.java similarity index 95% rename from transact/src/main/java/dev/dbos/transact/database/StepsDAO.java rename to transact/src/main/java/dev/dbos/transact/database/dao/StepsDAO.java index 602969f4..45015c3c 100644 --- a/transact/src/main/java/dev/dbos/transact/database/StepsDAO.java +++ b/transact/src/main/java/dev/dbos/transact/database/dao/StepsDAO.java @@ -1,5 +1,6 @@ -package dev.dbos.transact.database; +package dev.dbos.transact.database.dao; +import dev.dbos.transact.database.DbContext; import dev.dbos.transact.exceptions.*; import dev.dbos.transact.internal.DebugTriggers; import dev.dbos.transact.json.DBOSSerializer; @@ -19,7 +20,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -class StepsDAO { +public class StepsDAO { private StepsDAO() {} @@ -30,7 +31,7 @@ static void recordStepResult(DbContext ctx, StepResult result, long startTimeEpo recordStepResult(ctx, result, startTimeEpochMs, System.currentTimeMillis()); } - static void recordStepResult( + public static void recordStepResult( DbContext ctx, StepResult result, long startTimeEpochMs, long endTimeEpochMs) throws SQLException { try (var conn = ctx.getConnection()) { @@ -114,7 +115,7 @@ static StepResult checkStepResult( } } - static StepResult checkStepResult( + public static StepResult checkStepResult( Connection conn, String schema, String workflowId, int functionId, String functionName) throws SQLException { @@ -183,7 +184,7 @@ static StepResult checkStepResult( return recordedResult; } - static List listWorkflowSteps( + public static List listWorkflowSteps( DbContext ctx, String workflowId, Boolean loadOutput, Integer limit, Integer offset) throws SQLException { try (var conn = ctx.getConnection()) { @@ -279,7 +280,7 @@ static List listWorkflowSteps( return steps; } - static void sleep(DbContext ctx, String workflowUuid, int functionId, Duration duration) + public static void sleep(DbContext ctx, String workflowUuid, int functionId, Duration duration) throws SQLException { var sleepDuration = durableSleepDuration(ctx, workflowUuid, functionId, duration); logger.debug("Sleeping for duration {}", sleepDuration); @@ -314,7 +315,7 @@ static String getCheckpointName(Connection conn, String schema, String workflowI } } - static boolean patch(DbContext ctx, String workflowId, int functionId, String patchName) + public static boolean patch(DbContext ctx, String workflowId, int functionId, String patchName) throws SQLException { Objects.requireNonNull(patchName, "patchName cannot be null"); try (var conn = ctx.getConnection()) { @@ -329,8 +330,8 @@ static boolean patch(DbContext ctx, String workflowId, int functionId, String pa } } - static boolean deprecatePatch(DbContext ctx, String workflowId, int functionId, String patchName) - throws SQLException { + public static boolean deprecatePatch( + DbContext ctx, String workflowId, int functionId, String patchName) throws SQLException { Objects.requireNonNull(patchName, "patchName cannot be null"); try (var conn = ctx.getConnection()) { var checkpointName = getCheckpointName(conn, ctx.schema(), workflowId, functionId); diff --git a/transact/src/main/java/dev/dbos/transact/database/StreamsDAO.java b/transact/src/main/java/dev/dbos/transact/database/dao/StreamsDAO.java similarity index 92% rename from transact/src/main/java/dev/dbos/transact/database/StreamsDAO.java rename to transact/src/main/java/dev/dbos/transact/database/dao/StreamsDAO.java index 1d428884..02a7a05f 100644 --- a/transact/src/main/java/dev/dbos/transact/database/StreamsDAO.java +++ b/transact/src/main/java/dev/dbos/transact/database/dao/StreamsDAO.java @@ -1,5 +1,6 @@ -package dev.dbos.transact.database; +package dev.dbos.transact.database.dao; +import dev.dbos.transact.database.DbContext; import dev.dbos.transact.json.SerializationUtil; import dev.dbos.transact.workflow.internal.StepResult; @@ -10,11 +11,11 @@ import java.util.List; import java.util.Map; -class StreamsDAO { +public class StreamsDAO { private StreamsDAO() {} - static void writeStreamFromStep( + public static void writeStreamFromStep( DbContext ctx, String workflowId, int functionId, @@ -27,7 +28,7 @@ static void writeStreamFromStep( } } - static void writeStreamFromWorkflow( + public static void writeStreamFromWorkflow( DbContext ctx, String workflowId, int functionId, @@ -125,13 +126,13 @@ SELECT COALESCE(MAX("offset"), -1) + 1 } } - static void closeStream(DbContext ctx, String workflowId, int functionId, String key) + public static void closeStream(DbContext ctx, String workflowId, int functionId, String key) throws SQLException { writeStreamFromWorkflow( ctx, workflowId, functionId, key, STREAM_CLOSED_SENTINEL, "portable_json"); } - static Object readStream(DbContext ctx, String workflowId, String key, int offset) + public static Object readStream(DbContext ctx, String workflowId, String key, int offset) throws SQLException { String sql = """ @@ -162,7 +163,7 @@ static Object readStream(DbContext ctx, String workflowId, String key, int offse } } - static Map> getAllStreamEntries(DbContext ctx, String workflowId) + public static Map> getAllStreamEntries(DbContext ctx, String workflowId) throws SQLException { String sql = """ diff --git a/transact/src/main/java/dev/dbos/transact/database/WorkflowDAO.java b/transact/src/main/java/dev/dbos/transact/database/dao/WorkflowDAO.java similarity index 96% rename from transact/src/main/java/dev/dbos/transact/database/WorkflowDAO.java rename to transact/src/main/java/dev/dbos/transact/database/dao/WorkflowDAO.java index 9d8ac061..08408d81 100644 --- a/transact/src/main/java/dev/dbos/transact/database/WorkflowDAO.java +++ b/transact/src/main/java/dev/dbos/transact/database/dao/WorkflowDAO.java @@ -1,6 +1,11 @@ -package dev.dbos.transact.database; +package dev.dbos.transact.database.dao; import dev.dbos.transact.Constants; +import dev.dbos.transact.database.DbContext; +import dev.dbos.transact.database.MetricData; +import dev.dbos.transact.database.Result; +import dev.dbos.transact.database.SystemDatabase; +import dev.dbos.transact.database.WorkflowInitResult; import dev.dbos.transact.exceptions.DBOSAwaitedWorkflowCancelledException; import dev.dbos.transact.exceptions.DBOSConflictingWorkflowException; import dev.dbos.transact.exceptions.DBOSMaxRecoveryAttemptsExceededException; @@ -49,7 +54,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -class WorkflowDAO { +public class WorkflowDAO { private static final Logger logger = LoggerFactory.getLogger(WorkflowDAO.class); @@ -69,7 +74,7 @@ class WorkflowDAO { private WorkflowDAO() {} - static WorkflowInitResult initWorkflowStatus( + public static WorkflowInitResult initWorkflowStatus( DbContext ctx, WorkflowStatusInternal initStatus, Integer maxRetries, @@ -346,7 +351,7 @@ static void updateWorkflowOutcome( * @param workflowId id of the workflow * @param result output serialized as json */ - static void recordWorkflowOutput(DbContext ctx, String workflowId, String result) + public static void recordWorkflowOutput(DbContext ctx, String workflowId, String result) throws SQLException { try (var conn = ctx.getConnection()) { @@ -360,7 +365,7 @@ static void recordWorkflowOutput(DbContext ctx, String workflowId, String result * @param workflowId id of the workflow * @param error output serialized as json */ - static void recordWorkflowError(DbContext ctx, String workflowId, String error) + public static void recordWorkflowError(DbContext ctx, String workflowId, String error) throws SQLException { try (var conn = ctx.getConnection()) { @@ -368,7 +373,8 @@ static void recordWorkflowError(DbContext ctx, String workflowId, String error) } } - static String getWorkflowSerialization(DbContext ctx, String workflowId) throws SQLException { + public static String getWorkflowSerialization(DbContext ctx, String workflowId) + throws SQLException { var sql = "SELECT serialization FROM \"%s\".workflow_status WHERE workflow_uuid = ?" .formatted(ctx.schema()); @@ -384,14 +390,15 @@ static String getWorkflowSerialization(DbContext ctx, String workflowId) throws return null; } - static WorkflowStatus getWorkflowStatus(DbContext ctx, String workflowId) throws SQLException { + public static WorkflowStatus getWorkflowStatus(DbContext ctx, String workflowId) + throws SQLException { try (var conn = ctx.getConnection()) { return getWorkflowStatus(conn, ctx.schema(), ctx.serializer(), workflowId); } } - static WorkflowStatus getWorkflowStatus( + public static WorkflowStatus getWorkflowStatus( Connection conn, String schema, DBOSSerializer serializer, String workflowId) throws SQLException { if (Objects.requireNonNull(workflowId, "workflowId must not be null").isEmpty()) { @@ -414,7 +421,7 @@ static WorkflowStatus getWorkflowStatus( return null; } - static void setWorkflowDelay(DbContext ctx, String workflowId, WorkflowDelay delay) + public static void setWorkflowDelay(DbContext ctx, String workflowId, WorkflowDelay delay) throws SQLException { Objects.requireNonNull(workflowId, "workflowId must not be null"); Objects.requireNonNull(delay, "delay must not be null"); @@ -449,7 +456,7 @@ static void setWorkflowDelay(DbContext ctx, String workflowId, WorkflowDelay del } } - static void transitionDelayedWorkflows(DbContext ctx) throws SQLException { + public static void transitionDelayedWorkflows(DbContext ctx) throws SQLException { var sql = """ UPDATE "%s".workflow_status @@ -469,7 +476,7 @@ static void transitionDelayedWorkflows(DbContext ctx) throws SQLException { } } - static List listWorkflows(DbContext ctx, ListWorkflowsInput input) + public static List listWorkflows(DbContext ctx, ListWorkflowsInput input) throws SQLException { DBOSSerializer serializer = ctx.serializer(); @@ -645,7 +652,7 @@ static List listWorkflows(DbContext ctx, ListWorkflowsInput inpu return workflows; } - static List getWorkflowAggregates( + public static List getWorkflowAggregates( DbContext ctx, GetWorkflowAggregatesInput input) throws SQLException { if (input == null) { @@ -815,7 +822,7 @@ private static WorkflowStatus resultsToWorkflowStatus( return info; } - static List getPendingWorkflows( + public static List getPendingWorkflows( DbContext ctx, List executorIds, String appVersion) throws SQLException { var input = new ListWorkflowsInput() @@ -826,7 +833,7 @@ static List getPendingWorkflows( } @SuppressWarnings("unchecked") - static Result awaitWorkflowResult( + public static Result awaitWorkflowResult( DbContext ctx, Duration dbPollingInterval, String workflowId) throws SQLException { DBOSSerializer serializer = ctx.serializer(); @@ -882,7 +889,7 @@ static Result awaitWorkflowResult( } } - static void recordChildWorkflow( + public static void recordChildWorkflow( DbContext ctx, String parentId, String childId, // workflowId of the child @@ -899,8 +906,8 @@ static void recordChildWorkflow( } } - static Optional checkChildWorkflow(DbContext ctx, String workflowUuid, int functionId) - throws SQLException { + public static Optional checkChildWorkflow( + DbContext ctx, String workflowUuid, int functionId) throws SQLException { final String sql = """ @@ -931,7 +938,7 @@ private static List filterNullsAndBlanks(List workflowIds) { return workflowIds.stream().filter(id -> id != null && !id.isBlank()).toList(); } - static void cancelWorkflows(DbContext ctx, List workflowIds) throws SQLException { + public static void cancelWorkflows(DbContext ctx, List workflowIds) throws SQLException { List filtered = filterNullsAndBlanks(workflowIds); if (filtered.isEmpty()) { return; @@ -964,7 +971,7 @@ AND status NOT IN (?, ?) } } - static void resumeWorkflows(DbContext ctx, List workflowIds, String queueName) + public static void resumeWorkflows(DbContext ctx, List workflowIds, String queueName) throws SQLException { List filtered = filterNullsAndBlanks(workflowIds); if (filtered.isEmpty()) { @@ -1002,8 +1009,8 @@ AND status NOT IN (?, ?) } } - static void deleteWorkflows(DbContext ctx, List workflowIds, boolean deleteChildren) - throws SQLException { + public static void deleteWorkflows( + DbContext ctx, List workflowIds, boolean deleteChildren) throws SQLException { List filtered = filterNullsAndBlanks(workflowIds); if (filtered.isEmpty()) { return; @@ -1036,7 +1043,8 @@ static void deleteWorkflows(DbContext ctx, List workflowIds, boolean del } } - static Set getWorkflowChildren(DbContext ctx, String workflowId) throws SQLException { + public static Set getWorkflowChildren(DbContext ctx, String workflowId) + throws SQLException { var children = new HashSet(); var toProcess = new ArrayDeque(); toProcess.add(workflowId); @@ -1069,7 +1077,7 @@ static Set getWorkflowChildren(DbContext ctx, String workflowId) throws return children; } - static String forkWorkflow( + public static String forkWorkflow( DbContext ctx, String originalWorkflowId, int startStep, ForkOptions options) throws SQLException { @@ -1307,7 +1315,7 @@ private static Instant getRowsCutoff(Connection conn, String schema, long rowsTh return null; } - static void garbageCollect(DbContext ctx, Instant cutoff, Long rowsThreshold) + public static void garbageCollect(DbContext ctx, Instant cutoff, Long rowsThreshold) throws SQLException { try (var conn = ctx.getConnection()) { @@ -1338,7 +1346,7 @@ static void garbageCollect(DbContext ctx, Instant cutoff, Long rowsThreshold) } } - static List getMetrics(DbContext ctx, Instant startTime, Instant endTime) + public static List getMetrics(DbContext ctx, Instant startTime, Instant endTime) throws SQLException { final var start = Objects.requireNonNull(startTime).toEpochMilli(); final var end = Objects.requireNonNull(endTime).toEpochMilli(); @@ -1469,7 +1477,7 @@ static List listWorkflowStreams(Connection conn, String schema, return streams; } - static List exportWorkflow( + public static List exportWorkflow( DbContext ctx, String workflowId, boolean exportChildren) throws SQLException { var workflowIds = @@ -1495,7 +1503,8 @@ static List exportWorkflow( return workflows; } - static void importWorkflow(DbContext ctx, List workflows) throws SQLException { + public static void importWorkflow(DbContext ctx, List workflows) + throws SQLException { DBOSSerializer serializer = ctx.serializer(); var wfSQL = @@ -1685,7 +1694,8 @@ static void importWorkflow(DbContext ctx, List workflows) thro } } - static Map getAllEvents(DbContext ctx, String workflowId) throws SQLException { + public static Map getAllEvents(DbContext ctx, String workflowId) + throws SQLException { try (var conn = ctx.getConnection()) { var events = listWorkflowEvents(conn, ctx.schema(), workflowId); var result = new LinkedHashMap(); diff --git a/transact/src/main/java/dev/dbos/transact/database/SignalKey.java b/transact/src/main/java/dev/dbos/transact/database/signal/SignalKey.java similarity index 91% rename from transact/src/main/java/dev/dbos/transact/database/SignalKey.java rename to transact/src/main/java/dev/dbos/transact/database/signal/SignalKey.java index 93ac1b5a..0524483a 100644 --- a/transact/src/main/java/dev/dbos/transact/database/SignalKey.java +++ b/transact/src/main/java/dev/dbos/transact/database/signal/SignalKey.java @@ -1,6 +1,6 @@ -package dev.dbos.transact.database; +package dev.dbos.transact.database.signal; -sealed interface SignalKey +public sealed interface SignalKey permits SignalKey.Cancellation, SignalKey.Event, SignalKey.Message, SignalKey.Shutdown { public enum WakeReason { diff --git a/transact/src/main/java/dev/dbos/transact/database/SignalMap.java b/transact/src/main/java/dev/dbos/transact/database/signal/SignalMap.java similarity index 88% rename from transact/src/main/java/dev/dbos/transact/database/SignalMap.java rename to transact/src/main/java/dev/dbos/transact/database/signal/SignalMap.java index 46c3f18f..10cfb5ed 100644 --- a/transact/src/main/java/dev/dbos/transact/database/SignalMap.java +++ b/transact/src/main/java/dev/dbos/transact/database/signal/SignalMap.java @@ -1,6 +1,6 @@ -package dev.dbos.transact.database; +package dev.dbos.transact.database.signal; -import dev.dbos.transact.database.SignalKey.WakeReason; +import dev.dbos.transact.database.signal.SignalKey.WakeReason; import java.time.Duration; import java.util.Objects; @@ -11,7 +11,7 @@ import java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicInteger; -class SignalMap { +public class SignalMap { private static class Entry { final CompletableFuture future = new CompletableFuture<>(); final AtomicInteger refs = new AtomicInteger(1); @@ -52,7 +52,7 @@ public void signal(K key) { } } - static WakeReason awaitAny(Duration timeout, Subscription... subscriptions) { + public static WakeReason awaitAny(Duration timeout, Subscription... subscriptions) { try { return (WakeReason) CompletableFuture.anyOf(subscriptions).get(timeout.toMillis(), TimeUnit.MILLISECONDS); diff --git a/transact/src/main/java/dev/dbos/transact/database/signal/Subscription.java b/transact/src/main/java/dev/dbos/transact/database/signal/Subscription.java new file mode 100644 index 00000000..5cedc6da --- /dev/null +++ b/transact/src/main/java/dev/dbos/transact/database/signal/Subscription.java @@ -0,0 +1,18 @@ +package dev.dbos.transact.database.signal; + +import dev.dbos.transact.database.signal.SignalKey.WakeReason; + +import java.util.concurrent.CompletableFuture; + +public class Subscription extends CompletableFuture implements AutoCloseable { + private final Runnable onClose; + + public Subscription(Runnable onClose) { + this.onClose = onClose; + } + + @Override + public void close() { + onClose.run(); + } +} diff --git a/transact/src/test/java/dev/dbos/transact/database/SignalMapTest.java b/transact/src/test/java/dev/dbos/transact/database/SignalMapTest.java index d72988bd..d051e001 100644 --- a/transact/src/test/java/dev/dbos/transact/database/SignalMapTest.java +++ b/transact/src/test/java/dev/dbos/transact/database/SignalMapTest.java @@ -7,7 +7,10 @@ import static org.junit.jupiter.api.Assertions.assertTimeoutPreemptively; import static org.junit.jupiter.api.Assertions.assertTrue; -import dev.dbos.transact.database.SignalKey.WakeReason; +import dev.dbos.transact.database.signal.SignalKey; +import dev.dbos.transact.database.signal.SignalKey.WakeReason; +import dev.dbos.transact.database.signal.SignalMap; +import dev.dbos.transact.database.signal.Subscription; import java.time.Duration; import java.util.concurrent.CompletableFuture; From e6f646b5e27d57796a2ab4af8b8a129fa64a5ecc Mon Sep 17 00:00:00 2001 From: Harry Pierson Date: Fri, 15 May 2026 12:49:56 -0700 Subject: [PATCH 13/27] reworked getEvent --- .../transact/database/GetEventCaller.java | 3 + .../database/GetWorkflowEventContext.java | 3 - .../transact/database/SystemDatabase.java | 5 +- .../database/dao/NotificationsDAO.java | 217 ++++++++---------- .../dbos/transact/execution/DBOSExecutor.java | 12 +- 5 files changed, 111 insertions(+), 129 deletions(-) create mode 100644 transact/src/main/java/dev/dbos/transact/database/GetEventCaller.java delete mode 100644 transact/src/main/java/dev/dbos/transact/database/GetWorkflowEventContext.java diff --git a/transact/src/main/java/dev/dbos/transact/database/GetEventCaller.java b/transact/src/main/java/dev/dbos/transact/database/GetEventCaller.java new file mode 100644 index 00000000..de7f514c --- /dev/null +++ b/transact/src/main/java/dev/dbos/transact/database/GetEventCaller.java @@ -0,0 +1,3 @@ +package dev.dbos.transact.database; + +public record GetEventCaller(String workflowId, int stepId, int timeoutStepId) {} diff --git a/transact/src/main/java/dev/dbos/transact/database/GetWorkflowEventContext.java b/transact/src/main/java/dev/dbos/transact/database/GetWorkflowEventContext.java deleted file mode 100644 index e183b462..00000000 --- a/transact/src/main/java/dev/dbos/transact/database/GetWorkflowEventContext.java +++ /dev/null @@ -1,3 +0,0 @@ -package dev.dbos.transact.database; - -public record GetWorkflowEventContext(String workflowId, int functionId, int timeoutFunctionId) {} diff --git a/transact/src/main/java/dev/dbos/transact/database/SystemDatabase.java b/transact/src/main/java/dev/dbos/transact/database/SystemDatabase.java index a82943d1..0ed7e592 100644 --- a/transact/src/main/java/dev/dbos/transact/database/SystemDatabase.java +++ b/transact/src/main/java/dev/dbos/transact/database/SystemDatabase.java @@ -447,13 +447,12 @@ public void setEvent( ctx, workflowId, functionId, key, message, asStep, serialization)); } - public Object getEvent( - String targetId, String key, Duration timeout, GetWorkflowEventContext callerCtx) { + public Object getEvent(String targetId, String key, Duration timeout, GetEventCaller caller) { return dbRetry( () -> NotificationsDAO.getEvent( - ctx, notificationSource, dbPollingInterval, targetId, key, timeout, callerCtx)); + ctx, targetId, key, timeout, caller, dbPollingInterval, notificationSource)); } public void sleep(String workflowId, int functionId, Duration duration) { diff --git a/transact/src/main/java/dev/dbos/transact/database/dao/NotificationsDAO.java b/transact/src/main/java/dev/dbos/transact/database/dao/NotificationsDAO.java index 4de20084..a4be1053 100644 --- a/transact/src/main/java/dev/dbos/transact/database/dao/NotificationsDAO.java +++ b/transact/src/main/java/dev/dbos/transact/database/dao/NotificationsDAO.java @@ -2,7 +2,7 @@ import dev.dbos.transact.Constants; import dev.dbos.transact.database.DbContext; -import dev.dbos.transact.database.GetWorkflowEventContext; +import dev.dbos.transact.database.GetEventCaller; import dev.dbos.transact.database.SystemDatabase.NotifcationRegistry; import dev.dbos.transact.database.signal.SignalKey; import dev.dbos.transact.database.signal.SignalMap; @@ -21,8 +21,11 @@ import java.util.ArrayList; import java.util.List; import java.util.Objects; +import java.util.Optional; import java.util.UUID; +import org.jspecify.annotations.NonNull; +import org.jspecify.annotations.Nullable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -386,128 +389,108 @@ public static void setEvent( } } + private record GetEventResult(String value, String serialization) {} + + private static Optional getEvent( + DbContext ctx, @NonNull String workflowId, @NonNull String key) throws SQLException { + var sql = + """ + SELECT value, serialization FROM "%s".workflow_events WHERE workflow_uuid = ? AND key = ? + """ + .formatted(ctx.schema()); + try (var conn = ctx.getConnection(); + var stmt = conn.prepareStatement(sql)) { + stmt.setString(1, workflowId); + stmt.setString(2, key); + try (var rs = stmt.executeQuery()) { + if (rs.next()) { + var value = rs.getString("value"); + var serialization = rs.getString("serialization"); + return Optional.of(new GetEventResult(value, serialization)); + } + } + } + + return Optional.empty(); + } + public static Object getEvent( DbContext ctx, - NotifcationRegistry notifcationRegistry, - Duration dbPollingInterval, - String targetUuid, + String workflowId, String key, Duration timeout, - GetWorkflowEventContext callerCtx) + @Nullable GetEventCaller caller, + Duration dbPollingInterval, + NotifcationRegistry notifcationRegistry) throws SQLException { - return null; - // DBOSSerializer serializer = ctx.serializer(); - // var startTime = System.currentTimeMillis(); - // String functionName = "DBOS.getEvent"; - - // if (callerCtx != null) { - // StepResult recordedOutput; - // try (Connection conn = ctx.getConnection()) { - // recordedOutput = - // StepsDAO.checkStepExecutionTxn( - // conn, ctx.schema(), callerCtx.workflowId(), callerCtx.functionId(), - // functionName); - // } - - // if (recordedOutput != null) { - // logger.debug("Replaying getEvent, id: {}, key: {}", callerCtx.functionId(), key); - // if (recordedOutput.output() != null) { - // return SerializationUtil.deserializeValue( - // recordedOutput.output(), recordedOutput.serialization(), serializer); - // } else { - // throw new RuntimeException("No output recorded in the last getEvent"); - // } - // } else { - // logger.debug("Running getEvent, id: {}, key: {}", callerCtx.functionId(), key); - // } - // } - - // String payload = targetUuid + "::" + key; - // NotificationListenerService.LockConditionPair lockConditionPair = - // notificationService.getOrCreateNotificationCondition(payload); - - // lockConditionPair.lock.lock(); - // try { - // Object value = null; - // final String sql = - // """ - // SELECT value, serialization FROM "%s".workflow_events WHERE workflow_uuid = ? AND key - // = ? - // """ - // .formatted(ctx.schema()); - - // double actualTimeout = - // Objects.requireNonNull(timeout, "getEvent timeout cannot be null").toMillis(); - // var targetTime = System.currentTimeMillis() + actualTimeout; - // var checkedDBForSleep = false; - // var hasExistingNotification = false; - - // while (true) { - // if (ctx.isClosed()) throw new IllegalStateException("SystemDatabase is closed"); - // try (Connection conn = ctx.getConnection(); - // PreparedStatement stmt = conn.prepareStatement(sql)) { - - // stmt.setString(1, targetUuid); - // stmt.setString(2, key); - - // try (ResultSet rs = stmt.executeQuery()) { - // if (rs.next()) { - // String serializedValue = rs.getString("value"); - // String serialization = rs.getString("serialization"); - // value = - // SerializationUtil.deserializeValue(serializedValue, serialization, serializer); - // hasExistingNotification = true; - // } - // } - // } - - // if (hasExistingNotification) break; - // var nowTime = System.currentTimeMillis(); - // if (nowTime > targetTime) break; - - // if (callerCtx != null && !checkedDBForSleep) { - // actualTimeout = - // StepsDAO.durableSleepDuration( - // ctx, callerCtx.workflowId(), callerCtx.timeoutFunctionId(), timeout) - // .toMillis(); - // targetTime = System.currentTimeMillis() + actualTimeout; - // checkedDBForSleep = true; - // if (nowTime > targetTime) break; - // } - - // try { - // long timeoutms = (long) (targetTime - nowTime); - // logger.debug("Waiting for notification {}...", timeout); - // lockConditionPair.condition.await( - // Math.min(timeoutms, dbPollingInterval.toMillis()), TimeUnit.MILLISECONDS); - // } catch (InterruptedException e) { - // Thread.currentThread().interrupt(); - // throw new RuntimeException("Interrupted while waiting for event", e); - // } - // } - - // if (callerCtx != null) { - // var toSaveSer = SerializationUtil.serializeValue(value, null, serializer); - // StepResult output = - // new StepResult( - // callerCtx.workflowId(), - // callerCtx.functionId(), - // functionName, - // null, - // null, - // null, - // toSaveSer.serialization()) - // .withOutput(toSaveSer.serializedValue()); - // StepsDAO.recordStepResultTxn(ctx, output, startTime, System.currentTimeMillis()); - // } - - // return value; - - // } finally { - // lockConditionPair.lock.unlock(); - // notificationService.unregisterNotificationCondition(payload); - // } + if (Objects.requireNonNull(workflowId).isEmpty()) { + throw new IllegalArgumentException("workflowId must not be empty"); + } + + var stepName = "DBOS.getEvent"; + + if (caller != null) { + var prevResult = + StepsDAO.checkStepResult(ctx, caller.workflowId(), caller.stepId(), stepName); + if (prevResult != null) { + logger.debug("Replaying getEvent, id: {}, key: {}", caller.stepId(), key); + return prevResult.toResult(ctx.serializer()); + } + logger.debug("Running getEvent, id: {}, key: {}", caller.stepId(), key); + } + + var startTime = Instant.now(); + var eventKey = new SignalKey.Event(workflowId, key); + dbPollingInterval = Objects.requireNonNullElse(dbPollingInterval, Duration.ofSeconds(1)); + + GetEventResult result = null; + try (var eventSignal = notifcationRegistry.subscribe(eventKey)) { + while (true) { + var optResult = getEvent(ctx, workflowId, key); + if (optResult.isPresent()) { + result = optResult.get(); + break; + } + + // check cancelled (both workflowId and caller.workflowId) + + var sleepDuration = + caller != null + ? StepsDAO.durableSleepDuration( + ctx, caller.workflowId(), caller.timeoutStepId(), timeout) + : timeout.minus(Duration.between(startTime, Instant.now())); + + if (sleepDuration.isNegative() || sleepDuration.isZero()) { + result = new GetEventResult(null, null); + break; + } + + var loopDuration = + dbPollingInterval.compareTo(sleepDuration) <= 0 ? dbPollingInterval : sleepDuration; + + SignalMap.awaitAny(loopDuration, eventSignal); + } + } + + Objects.requireNonNull(result); + ctx.checkClosed(); + + if (caller != null) { + var stepResult = + new StepResult( + caller.workflowId(), + caller.stepId(), + stepName, + result.value(), + null, + null, + result.serialization()); + StepsDAO.recordStepResult(ctx, stepResult, startTime.toEpochMilli()); + } + + return SerializationUtil.deserializeValue( + result.value(), result.serialization(), ctx.serializer()); } public static List getAllNotifications(DbContext ctx, String workflowId) diff --git a/transact/src/main/java/dev/dbos/transact/execution/DBOSExecutor.java b/transact/src/main/java/dev/dbos/transact/execution/DBOSExecutor.java index 9caa2a28..0413799c 100644 --- a/transact/src/main/java/dev/dbos/transact/execution/DBOSExecutor.java +++ b/transact/src/main/java/dev/dbos/transact/execution/DBOSExecutor.java @@ -11,7 +11,7 @@ import dev.dbos.transact.context.DBOSContextHolder; import dev.dbos.transact.context.WorkflowInfo; import dev.dbos.transact.database.ExternalState; -import dev.dbos.transact.database.GetWorkflowEventContext; +import dev.dbos.transact.database.GetEventCaller; import dev.dbos.transact.database.Result; import dev.dbos.transact.database.StreamIterator; import dev.dbos.transact.database.SystemDatabase; @@ -471,11 +471,11 @@ public Object getEvent(String workflowId, String key, Duration timeout) { DBOSContext ctx = DBOSContextHolder.get(); if (ctx.isInWorkflow() && !ctx.isInStep()) { - int stepFunctionId = ctx.getAndIncrementFunctionId(); - int timeoutFunctionId = ctx.getAndIncrementFunctionId(); - GetWorkflowEventContext callerCtx = - new GetWorkflowEventContext(ctx.getWorkflowId(), stepFunctionId, timeoutFunctionId); - return systemDatabase.getEvent(workflowId, key, timeout, callerCtx); + int stepId = ctx.getAndIncrementFunctionId(); + int timeoutStepId = ctx.getAndIncrementFunctionId(); + GetEventCaller caller = + new GetEventCaller(ctx.getWorkflowId(), stepId, timeoutStepId); + return systemDatabase.getEvent(workflowId, key, timeout, caller); } return systemDatabase.getEvent(workflowId, key, timeout, null); From 1b04732511155e24e32ba67fd5fb0a496e620e99 Mon Sep 17 00:00:00 2001 From: Harry Pierson Date: Fri, 15 May 2026 12:50:15 -0700 Subject: [PATCH 14/27] spotless --- .../main/java/dev/dbos/transact/execution/DBOSExecutor.java | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/transact/src/main/java/dev/dbos/transact/execution/DBOSExecutor.java b/transact/src/main/java/dev/dbos/transact/execution/DBOSExecutor.java index 0413799c..92149f72 100644 --- a/transact/src/main/java/dev/dbos/transact/execution/DBOSExecutor.java +++ b/transact/src/main/java/dev/dbos/transact/execution/DBOSExecutor.java @@ -473,8 +473,7 @@ public Object getEvent(String workflowId, String key, Duration timeout) { if (ctx.isInWorkflow() && !ctx.isInStep()) { int stepId = ctx.getAndIncrementFunctionId(); int timeoutStepId = ctx.getAndIncrementFunctionId(); - GetEventCaller caller = - new GetEventCaller(ctx.getWorkflowId(), stepId, timeoutStepId); + GetEventCaller caller = new GetEventCaller(ctx.getWorkflowId(), stepId, timeoutStepId); return systemDatabase.getEvent(workflowId, key, timeout, caller); } From 32d937c2e9fba59973bafa137183b6a7b3320393 Mon Sep 17 00:00:00 2001 From: Harry Pierson Date: Fri, 15 May 2026 14:08:49 -0700 Subject: [PATCH 15/27] NullNotificationSource --- .../java/dev/dbos/transact/DBOSClient.java | 16 ++++- .../transact/database/SystemDatabase.java | 59 +++++++++++++++---- 2 files changed, 61 insertions(+), 14 deletions(-) diff --git a/transact/src/main/java/dev/dbos/transact/DBOSClient.java b/transact/src/main/java/dev/dbos/transact/DBOSClient.java index 76092069..01fdfed2 100644 --- a/transact/src/main/java/dev/dbos/transact/DBOSClient.java +++ b/transact/src/main/java/dev/dbos/transact/DBOSClient.java @@ -79,7 +79,7 @@ public T getResult() throws E { * @param password System database credential / password */ public DBOSClient(@NonNull String url, @NonNull String user, @NonNull String password) { - this(url, user, password, null, null); + this(url, user, password, null, null, true); } /** @@ -95,7 +95,7 @@ public DBOSClient( @NonNull String user, @NonNull String password, @Nullable String schema) { - this(url, user, password, schema, null); + this(url, user, password, schema, null, true); } /** @@ -113,8 +113,18 @@ public DBOSClient( @NonNull String password, @Nullable String schema, @Nullable DBOSSerializer serializer) { + this(url, user, password, schema, serializer, true); + } + + public DBOSClient( + @NonNull String url, + @NonNull String user, + @NonNull String password, + @Nullable String schema, + @Nullable DBOSSerializer serializer, + boolean useListenNotify) { this.serializer = serializer; - systemDatabase = new SystemDatabase(url, user, password, schema, serializer); + systemDatabase = new SystemDatabase(url, user, password, schema, serializer, useListenNotify); } /** diff --git a/transact/src/main/java/dev/dbos/transact/database/SystemDatabase.java b/transact/src/main/java/dev/dbos/transact/database/SystemDatabase.java index 0ed7e592..d7a8490e 100644 --- a/transact/src/main/java/dev/dbos/transact/database/SystemDatabase.java +++ b/transact/src/main/java/dev/dbos/transact/database/SystemDatabase.java @@ -11,6 +11,8 @@ import dev.dbos.transact.database.dao.StreamsDAO; import dev.dbos.transact.database.dao.WorkflowDAO; import dev.dbos.transact.database.signal.SignalKey; +import dev.dbos.transact.database.signal.SignalKey.Event; +import dev.dbos.transact.database.signal.SignalKey.Message; import dev.dbos.transact.database.signal.Subscription; import dev.dbos.transact.exceptions.*; import dev.dbos.transact.json.DBOSSerializer; @@ -57,6 +59,25 @@ public interface NotificationSource extends NotifcationRegistry { void close(); } + class NullNotificationSource implements NotificationSource { + + @Override + public Subscription subscribe(Message key) { + return new Subscription(() -> {}); + } + + @Override + public Subscription subscribe(Event key) { + return new Subscription(() -> {}); + } + + @Override + public void start() {} + + @Override + public void close() {} + } + private static final Logger logger = LoggerFactory.getLogger(SystemDatabase.class); public static String sanitizeSchema(String schema) { @@ -84,7 +105,11 @@ private static void validatePostgresDataSource(DataSource dataSource) { } private SystemDatabase( - DataSource dataSource, String schema, boolean created, DBOSSerializer serializer) { + DataSource dataSource, + String schema, + boolean created, + DBOSSerializer serializer, + boolean useListenNotify) { validatePostgresDataSource(dataSource); schema = sanitizeSchema(schema); if (schema.contains("\"")) { @@ -93,26 +118,37 @@ private SystemDatabase( this.ctx = new DbContext(dataSource, schema, serializer, this.closed::get); this.created = created; + try { + useListenNotify = isCockroach(dataSource) ? false : useListenNotify; + } catch (SQLException e) { + logger.error("Failed to determine if dataSouce is CockroachDB", e); + useListenNotify = false; + } - // TODO: NotificationPollingService - notificationSource = new NotificationListenerSource(dataSource); + notificationSource = + useListenNotify ? new NotificationListenerSource(dataSource) : new NullNotificationSource(); } - public SystemDatabase(String url, String user, String password, String schema) { - this(createDataSource(url, user, password), schema, true, null); + public SystemDatabase( + String url, + String user, + String password, + String schema, + DBOSSerializer serializer, + boolean useListenNotify) { + this(createDataSource(url, user, password), schema, true, serializer, useListenNotify); } - public SystemDatabase( - String url, String user, String password, String schema, DBOSSerializer serializer) { - this(createDataSource(url, user, password), schema, true, serializer); + public SystemDatabase(String url, String user, String password, String schema) { + this(createDataSource(url, user, password), schema, true, null, true); } public SystemDatabase(DataSource dataSource, String schema) { - this(dataSource, schema, false, null); + this(dataSource, schema, false, null, true); } public SystemDatabase(DataSource dataSource, String schema, DBOSSerializer serializer) { - this(dataSource, schema, false, serializer); + this(dataSource, schema, false, serializer, true); } public static SystemDatabase create(DBOSConfig config) { @@ -122,7 +158,8 @@ public static SystemDatabase create(DBOSConfig config) { config.dbUser(), config.dbPassword(), config.databaseSchema(), - config.serializer()); + config.serializer(), + config.useListenNotify()); } else { return new SystemDatabase(config.dataSource(), config.databaseSchema(), config.serializer()); } From dce9c3f6fda2b2b8e9bfc9f858432a2358ec1128 Mon Sep 17 00:00:00 2001 From: Harry Pierson Date: Sun, 17 May 2026 09:12:23 -0700 Subject: [PATCH 16/27] CRDB fixes --- .../dbos/transact/database/dao/QueuesDAO.java | 2 +- .../dbos/transact/client/PgSqlClientTest.java | 48 ++++++++++++------- 2 files changed, 31 insertions(+), 19 deletions(-) diff --git a/transact/src/main/java/dev/dbos/transact/database/dao/QueuesDAO.java b/transact/src/main/java/dev/dbos/transact/database/dao/QueuesDAO.java index 0c065766..f1544adf 100644 --- a/transact/src/main/java/dev/dbos/transact/database/dao/QueuesDAO.java +++ b/transact/src/main/java/dev/dbos/transact/database/dao/QueuesDAO.java @@ -210,7 +210,7 @@ SELECT executor_id, COUNT(*) as task_count started_at_epoch_ms = ?, workflow_deadline_epoch_ms = CASE WHEN workflow_timeout_ms IS NOT NULL AND workflow_deadline_epoch_ms IS NULL - THEN EXTRACT(epoch FROM NOW()) * 1000 + workflow_timeout_ms + THEN (EXTRACT(epoch FROM now()) * 1000)::bigint + workflow_timeout_ms ELSE workflow_deadline_epoch_ms END WHERE workflow_uuid = ? diff --git a/transact/src/test/java/dev/dbos/transact/client/PgSqlClientTest.java b/transact/src/test/java/dev/dbos/transact/client/PgSqlClientTest.java index a549972e..4c414623 100644 --- a/transact/src/test/java/dev/dbos/transact/client/PgSqlClientTest.java +++ b/transact/src/test/java/dev/dbos/transact/client/PgSqlClientTest.java @@ -198,28 +198,40 @@ String enqueueWorkflow( .toArray(String[]::new); var sql = """ - SELECT dbos.enqueue_workflow( - workflow_name => ?, - class_name => ?, - queue_name => ?, - positional_args => ?, - deduplication_id => ?, - timeout_ms => ?, - deadline_epoch_ms => ?) - """; + SELECT dbos.enqueue_workflow( + ?, -- workflow_name + ?, -- queue_name + ?, -- positional_args + ?, -- named_args + ?, -- class_name + ?, -- config_name + ?, -- workflow_id + ?, -- app_version + ?, -- timeout_ms + ?, -- deadline_epoch_ms + ?, -- deduplication_id + ?, -- priority + ? -- queue_partition_key + ) + """; try (var conn = dataSource.getConnection(); var stmt = conn.prepareCall(sql)) { Long timeoutMS = timeout == null ? null : timeout.toMillis(); Long deadlineMS = deadline == null ? null : deadline.toEpochMilli(); - stmt.setString(1, Objects.requireNonNull(workflowName)); - stmt.setString(2, "ClientServiceImpl"); - stmt.setString(3, "testQueue"); - stmt.setString(5, dedupId); - stmt.setObject(6, timeoutMS, Types.BIGINT); - stmt.setObject(7, deadlineMS, Types.BIGINT); - var argsArray = conn.createArrayOf("json", jsonArgs); - stmt.setObject(4, argsArray); + stmt.setString(1, Objects.requireNonNull(workflowName)); // workflow_name + stmt.setString(2, "testQueue"); // queue_name + stmt.setObject(3, argsArray); // positional_args + stmt.setObject(4, null); // named_args + stmt.setString(5, "ClientServiceImpl"); // class_name + stmt.setObject(6, null); // config_name + stmt.setObject(7, null); // workflow_id + stmt.setObject(8, null); // app_version + stmt.setObject(9, timeoutMS, Types.BIGINT); // timeout_ms + stmt.setObject(10, deadlineMS, Types.BIGINT); // deadline_epoch_ms + stmt.setString(11, dedupId); // deduplication_id + stmt.setObject(12, null); // priority + stmt.setObject(13, null); // queue_partition_key try (ResultSet rs = stmt.executeQuery()) { return rs.next() ? rs.getString(1) : null; } finally { @@ -232,7 +244,7 @@ void sendMessage(String destinationId, Object message, String topic, String idem throws SQLException, JsonProcessingException { String jsonMessage = MAPPER.writeValueAsString(Objects.requireNonNull(message)); - var sql = "SELECT dbos.send_message(?, ?::json, topic => ?, idempotency_key => ?)"; + var sql = "SELECT dbos.send_message(?, ?::json, ?, ?)"; try (var conn = dataSource.getConnection(); var stmt = conn.prepareCall(sql)) { stmt.setString(1, Objects.requireNonNull(destinationId)); From 022b0715ae31eb22052c3a6d826d567c2e34b2f9 Mon Sep 17 00:00:00 2001 From: Harry Pierson Date: Sun, 17 May 2026 11:36:39 -0700 Subject: [PATCH 17/27] don't automatically create database for tests --- .../dev/dbos/transact/config/ConfigTest.java | 8 ++++ .../migrations/MigrationManagerTest.java | 39 +++++++++++++++++++ .../txstep/JdbcStepFactoryInitTest.java | 6 +++ .../transact/txstep/JdbcStepFactoryTest.java | 2 + .../dev/dbos/transact/utils/PgContainer.java | 15 +++---- 5 files changed, 60 insertions(+), 10 deletions(-) diff --git a/transact/src/test/java/dev/dbos/transact/config/ConfigTest.java b/transact/src/test/java/dev/dbos/transact/config/ConfigTest.java index 68efe1cb..25304c96 100644 --- a/transact/src/test/java/dev/dbos/transact/config/ConfigTest.java +++ b/transact/src/test/java/dev/dbos/transact/config/ConfigTest.java @@ -276,6 +276,10 @@ public void calcAppVersionNotMatch() throws Exception { @Test public void configPGSimpleDataSource() throws Exception { + // dbos.launch doesn't attempt to create the database when + // receiving data source in DBOSConfig + pgContainer.createDatabase(); + var jdbcUrl = pgContainer.jdbcUrl(); assertTrue(jdbcUrl.startsWith("jdbc:")); @@ -316,6 +320,10 @@ public void configPGSimpleDataSource() throws Exception { @Test public void configHikariDataSource() throws Exception { + // need to create database since we are connecting + // the data source prior to dbos launch + pgContainer.createDatabase(); + var poolName = "dbos-configDataSource"; HikariConfig hikariConfig = new HikariConfig(); diff --git a/transact/src/test/java/dev/dbos/transact/migrations/MigrationManagerTest.java b/transact/src/test/java/dev/dbos/transact/migrations/MigrationManagerTest.java index 05c0e33e..2651341f 100644 --- a/transact/src/test/java/dev/dbos/transact/migrations/MigrationManagerTest.java +++ b/transact/src/test/java/dev/dbos/transact/migrations/MigrationManagerTest.java @@ -12,6 +12,7 @@ import java.sql.Connection; import java.sql.DatabaseMetaData; +import java.sql.DriverManager; import java.sql.ResultSet; import java.util.ArrayList; @@ -144,6 +145,39 @@ void testRunMigrations_customSchema(String schema) throws Exception { } } + @Test + void testRunMigrations_CreatesDatabaseIfNotExists() throws Exception { + var dbosConfig = pgContainer.dbosConfig(); + var pair = MigrationManager.extractDbAndPostgresUrl(pgContainer.jdbcUrl()); + + // Verify the database does not exist before running migrations + try (var conn = + DriverManager.getConnection( + pair.url(), pgContainer.username(), pgContainer.password()); + var stmt = conn.prepareStatement("SELECT 1 FROM pg_database WHERE datname = ?")) { + stmt.setString(1, pair.database()); + try (var rs = stmt.executeQuery()) { + assertFalse( + rs.next(), + "Database '%s' should not exist before runMigrations".formatted(pair.database())); + } + } + + MigrationManager.runMigrations(dbosConfig); + + // Verify the database now exists after running migrations + try (var conn = + DriverManager.getConnection( + pair.url(), pgContainer.username(), pgContainer.password()); + var stmt = conn.prepareStatement("SELECT 1 FROM pg_database WHERE datname = ?")) { + stmt.setString(1, pair.database()); + try (var rs = stmt.executeQuery()) { + assertTrue( + rs.next(), "Database '%s' should exist after runMigrations".formatted(pair.database())); + } + } + } + @Test void testRunMigrations_IsIdempotent() throws Exception { @@ -204,6 +238,11 @@ public void extractDbAndPostgresUrl() { @Test void testOriginalMigration1ThenAllMigrations_NotificationsPrimaryKey() throws Exception { + + // need to create database since we are connecting + // the data source prior to dbos launch + pgContainer.createDatabase(); + try (Connection conn = dataSource.getConnection()) { // Ensure schema and migration table exist MigrationManager.ensureDbosSchema(conn, Constants.DB_SCHEMA); diff --git a/transact/src/test/java/dev/dbos/transact/txstep/JdbcStepFactoryInitTest.java b/transact/src/test/java/dev/dbos/transact/txstep/JdbcStepFactoryInitTest.java index 7458e8df..a1b1c5a9 100644 --- a/transact/src/test/java/dev/dbos/transact/txstep/JdbcStepFactoryInitTest.java +++ b/transact/src/test/java/dev/dbos/transact/txstep/JdbcStepFactoryInitTest.java @@ -17,11 +17,17 @@ import java.util.Objects; import org.junit.jupiter.api.AutoClose; +import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; public class JdbcStepFactoryInitTest { @AutoClose final PgContainer pgContainer = new PgContainer(); + @BeforeEach + void beforeEach() throws SQLException { + pgContainer.createDatabase(); + } + static boolean validateSchema(Connection conn, String schema) throws SQLException { Objects.requireNonNull(schema); try (var rs = conn.getMetaData().getSchemas()) { diff --git a/transact/src/test/java/dev/dbos/transact/txstep/JdbcStepFactoryTest.java b/transact/src/test/java/dev/dbos/transact/txstep/JdbcStepFactoryTest.java index 61db3688..afa65de4 100644 --- a/transact/src/test/java/dev/dbos/transact/txstep/JdbcStepFactoryTest.java +++ b/transact/src/test/java/dev/dbos/transact/txstep/JdbcStepFactoryTest.java @@ -123,6 +123,8 @@ public class JdbcStepFactoryTest { @BeforeEach void beforeEach() throws SQLException { + pgContainer.createDatabase(); + dbosConfig = pgContainer.dbosConfig(); dataSource = pgContainer.dataSource(); diff --git a/transact/src/test/java/dev/dbos/transact/utils/PgContainer.java b/transact/src/test/java/dev/dbos/transact/utils/PgContainer.java index a466d222..68ebd6c7 100644 --- a/transact/src/test/java/dev/dbos/transact/utils/PgContainer.java +++ b/transact/src/test/java/dev/dbos/transact/utils/PgContainer.java @@ -3,9 +3,9 @@ import dev.dbos.transact.DBOSClient; import dev.dbos.transact.config.DBOSConfig; import dev.dbos.transact.database.SystemDatabase; +import dev.dbos.transact.migrations.MigrationManager; import java.sql.DriverManager; -import java.sql.SQLException; import java.util.ArrayList; import java.util.Objects; import java.util.UUID; @@ -61,15 +61,6 @@ public PgContainer() { pgContainer = acquire(); dbName = "test_" + UUID.randomUUID().toString().replace("-", ""); jdbcUrl = pgContainer.getJdbcUrl().replaceFirst("/[^/]+$", "/" + dbName); - - try (var conn = - DriverManager.getConnection( - pgContainer.getJdbcUrl(), pgContainer.getUsername(), pgContainer.getPassword()); - var stmt = conn.createStatement()) { - stmt.execute("CREATE DATABASE " + dbName); - } catch (SQLException e) { - throw new RuntimeException(e); - } } @Override @@ -114,4 +105,8 @@ public HikariDataSource dataSource() { public DBOSClient dbosClient() { return new DBOSClient(jdbcUrl(), username(), password()); } + + public void createDatabase() { + MigrationManager.createDatabaseIfNotExists(jdbcUrl(), username(), password()); + } } From 7ea9f74621b7e7838b3d8f475bd83703471a42d9 Mon Sep 17 00:00:00 2001 From: Harry Pierson Date: Sun, 17 May 2026 21:42:53 -0700 Subject: [PATCH 18/27] CRDB compat --- .../migrations/CockroachMigrationTest.java | 15 +- .../migrations/MigrationManagerTest.java | 33 +++-- .../txstep/JdbcStepFactoryInitTest.java | 2 +- .../transact/txstep/JdbcStepFactoryTest.java | 2 + .../dev/dbos/transact/utils/PgContainer.java | 129 ++++++++++++------ .../test/resources/junit-platform.properties | 2 + 6 files changed, 120 insertions(+), 63 deletions(-) diff --git a/transact/src/test/java/dev/dbos/transact/migrations/CockroachMigrationTest.java b/transact/src/test/java/dev/dbos/transact/migrations/CockroachMigrationTest.java index 69a61739..9a56e04c 100644 --- a/transact/src/test/java/dev/dbos/transact/migrations/CockroachMigrationTest.java +++ b/transact/src/test/java/dev/dbos/transact/migrations/CockroachMigrationTest.java @@ -5,25 +5,22 @@ import dev.dbos.transact.config.DBOSConfig; import dev.dbos.transact.database.SystemDatabase; +import dev.dbos.transact.utils.PgContainer; import java.sql.Connection; import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.parallel.Isolated; -import org.testcontainers.cockroachdb.CockroachContainer; -import org.testcontainers.postgresql.PostgreSQLContainer; /** * Tests that LISTEN/NOTIFY functions and triggers are created on PG and omitted on CRDB. Each test * spins up its own container so these tests are always tied to the specified DB type regardless of * how PgContainer is configured for the rest of the test suite. */ -@Isolated class CockroachMigrationTest { @Test void testPg_isCockroachReturnsFalse() throws Exception { - try (var pg = new PostgreSQLContainer("postgres:latest")) { + try (var pg = PgContainer.getPG()) { pg.start(); try (var ds = SystemDatabase.createDataSource(pg.getJdbcUrl(), pg.getUsername(), pg.getPassword())) { @@ -37,7 +34,7 @@ void testPg_isCockroachReturnsFalse() throws Exception { @Test void testCrdb_isCockroachReturnsTrue() throws Exception { - try (var crdb = new CockroachContainer("cockroachdb/cockroach:latest")) { + try (var crdb = PgContainer.getCRDB()) { crdb.start(); try (var ds = SystemDatabase.createDataSource( @@ -52,7 +49,7 @@ void testCrdb_isCockroachReturnsTrue() throws Exception { @Test void testPg_notifyFunctionsAndTriggersPresent() throws Exception { - try (var pg = new PostgreSQLContainer("postgres:latest")) { + try (var pg = PgContainer.getPG()) { pg.start(); try (var ds = SystemDatabase.createDataSource(pg.getJdbcUrl(), pg.getUsername(), pg.getPassword())) { @@ -72,7 +69,7 @@ void testPg_notifyFunctionsAndTriggersPresent() throws Exception { @Test void testPg_noListenNotify_notifyFunctionsAndTriggersAbsent() throws Exception { - try (var pg = new PostgreSQLContainer("postgres:latest")) { + try (var pg = PgContainer.getPG()) { pg.start(); try (var ds = SystemDatabase.createDataSource(pg.getJdbcUrl(), pg.getUsername(), pg.getPassword())) { @@ -97,7 +94,7 @@ void testPg_noListenNotify_notifyFunctionsAndTriggersAbsent() throws Exception { @Test void testCrdb_notifyFunctionsAndTriggersAbsent() throws Exception { - try (var crdb = new CockroachContainer("cockroachdb/cockroach:latest")) { + try (var crdb = PgContainer.getCRDB()) { crdb.start(); try (var ds = SystemDatabase.createDataSource( diff --git a/transact/src/test/java/dev/dbos/transact/migrations/MigrationManagerTest.java b/transact/src/test/java/dev/dbos/transact/migrations/MigrationManagerTest.java index 2651341f..254a5883 100644 --- a/transact/src/test/java/dev/dbos/transact/migrations/MigrationManagerTest.java +++ b/transact/src/test/java/dev/dbos/transact/migrations/MigrationManagerTest.java @@ -18,6 +18,7 @@ import com.zaxxer.hikari.HikariDataSource; import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Assumptions; import org.junit.jupiter.api.AutoClose; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -39,17 +40,20 @@ class MigrationManagerTest { "workflow_status" }; - // Expected functions after migrations - static final String[] EXPECTED_FUNCTIONS = { - "notifications_function", "workflow_events_function", "enqueue_workflow", "send_message" + // Expected functions after migrations (always present) + static final String[] EXPECTED_FUNCTIONS = {"enqueue_workflow", "send_message"}; + + // Expected LISTEN/NOTIFY functions after migrations (PG only, absent on CRDB) + static final String[] EXPECTED_NOTIFY_FUNCTIONS = { + "notifications_function", "workflow_events_function" }; - // Expected LISTEN/NOTIFY triggers after migrations (only when useListenNotify=true) + // Expected LISTEN/NOTIFY triggers after migrations (PG only, absent on CRDB) static final String[] EXPECTED_NOTIFY_TRIGGERS = { "dbos_notifications_trigger", "dbos_workflow_events_trigger" }; - @AutoClose final PgContainer pgContainer = new PgContainer(); + @AutoClose final PgContainer pgContainer = PgContainer.createFresh(); @AutoClose HikariDataSource dataSource; @BeforeEach @@ -75,8 +79,13 @@ void testRunMigrations_CreatesTables() throws Exception { for (String function : EXPECTED_FUNCTIONS) { assertFunctionExists(metaData, function); } - for (String trigger : EXPECTED_NOTIFY_TRIGGERS) { - assertTriggerExists(conn, trigger); + if (!PgContainer.USE_COCKROACH_DB) { + for (String function : EXPECTED_NOTIFY_FUNCTIONS) { + assertFunctionExists(metaData, function); + } + for (String trigger : EXPECTED_NOTIFY_TRIGGERS) { + assertTriggerExists(conn, trigger); + } } var migrations = new ArrayList<>(MigrationManager.getMigrations(Constants.DB_SCHEMA, true)); @@ -135,8 +144,13 @@ void testRunMigrations_customSchema(String schema) throws Exception { for (String function : EXPECTED_FUNCTIONS) { assertFunctionExists(metaData, function, schema); } - for (String trigger : EXPECTED_NOTIFY_TRIGGERS) { - assertTriggerExists(conn, trigger, schema); + if (!PgContainer.USE_COCKROACH_DB) { + for (String function : EXPECTED_NOTIFY_FUNCTIONS) { + assertFunctionExists(metaData, function, schema); + } + for (String trigger : EXPECTED_NOTIFY_TRIGGERS) { + assertTriggerExists(conn, trigger, schema); + } } var migrations = new ArrayList<>(MigrationManager.getMigrations(schema, true)); @@ -238,6 +252,7 @@ public void extractDbAndPostgresUrl() { @Test void testOriginalMigration1ThenAllMigrations_NotificationsPrimaryKey() throws Exception { + Assumptions.assumeFalse(PgContainer.USE_COCKROACH_DB, "PG-only migration history test"); // need to create database since we are connecting // the data source prior to dbos launch diff --git a/transact/src/test/java/dev/dbos/transact/txstep/JdbcStepFactoryInitTest.java b/transact/src/test/java/dev/dbos/transact/txstep/JdbcStepFactoryInitTest.java index a1b1c5a9..3a9dafbd 100644 --- a/transact/src/test/java/dev/dbos/transact/txstep/JdbcStepFactoryInitTest.java +++ b/transact/src/test/java/dev/dbos/transact/txstep/JdbcStepFactoryInitTest.java @@ -21,7 +21,7 @@ import org.junit.jupiter.api.Test; public class JdbcStepFactoryInitTest { - @AutoClose final PgContainer pgContainer = new PgContainer(); + @AutoClose final PgContainer pgContainer = PgContainer.createFresh(); @BeforeEach void beforeEach() throws SQLException { diff --git a/transact/src/test/java/dev/dbos/transact/txstep/JdbcStepFactoryTest.java b/transact/src/test/java/dev/dbos/transact/txstep/JdbcStepFactoryTest.java index afa65de4..e2cba816 100644 --- a/transact/src/test/java/dev/dbos/transact/txstep/JdbcStepFactoryTest.java +++ b/transact/src/test/java/dev/dbos/transact/txstep/JdbcStepFactoryTest.java @@ -130,6 +130,8 @@ void beforeEach() throws SQLException { try (var conn = dataSource.getConnection(); var stmt = conn.createStatement()) { + stmt.execute("DROP TABLE IF EXISTS greetings"); + stmt.execute("DROP TABLE IF EXISTS dbos.tx_step_outputs"); stmt.execute( "CREATE TABLE greetings(name text NOT NULL, greet_count integer DEFAULT 0, PRIMARY KEY(name))"); } diff --git a/transact/src/test/java/dev/dbos/transact/utils/PgContainer.java b/transact/src/test/java/dev/dbos/transact/utils/PgContainer.java index 68ebd6c7..8d10dd42 100644 --- a/transact/src/test/java/dev/dbos/transact/utils/PgContainer.java +++ b/transact/src/test/java/dev/dbos/transact/utils/PgContainer.java @@ -5,74 +5,115 @@ import dev.dbos.transact.database.SystemDatabase; import dev.dbos.transact.migrations.MigrationManager; +import java.sql.Connection; import java.sql.DriverManager; -import java.util.ArrayList; +import java.sql.SQLException; import java.util.Objects; -import java.util.UUID; -import java.util.concurrent.ArrayBlockingQueue; -import java.util.concurrent.BlockingQueue; -import java.util.concurrent.Semaphore; +import java.util.Queue; +import java.util.concurrent.ConcurrentLinkedQueue; import com.zaxxer.hikari.HikariDataSource; +import org.testcontainers.cockroachdb.CockroachContainer; +import org.testcontainers.containers.JdbcDatabaseContainer; import org.testcontainers.postgresql.PostgreSQLContainer; +// TODO: custom junit.jupiter.execution.parallel.config.strategy / dynamic.factor / +// fixed.parallelism = 2 reader + public class PgContainer implements AutoCloseable { - private static final int SIZE = Runtime.getRuntime().availableProcessors(); - private static final BlockingQueue POOL = new ArrayBlockingQueue<>(SIZE); - private static final Semaphore PERMITS = new Semaphore(SIZE); - - static { - Runtime.getRuntime() - .addShutdownHook( - new Thread( - () -> { - var containers = new ArrayList(); - POOL.drainTo(containers); - containers.forEach(PostgreSQLContainer::stop); - })); - } - - static PostgreSQLContainer acquire() { - try { - PERMITS.acquire(); - var container = POOL.poll(); - if (container == null) { - container = new PostgreSQLContainer("postgres:latest"); - container.start(); + public static final boolean USE_COCKROACH_DB = + Boolean.parseBoolean(System.getenv("DBOS_TEST_USE_COCKROACH_DB")); + private static final String DB_NAME = "dbos_test_db"; + + private static final Queue> POOL = new ConcurrentLinkedQueue<>(); + + public static PostgreSQLContainer getPG() { + return new PostgreSQLContainer("postgres:latest"); + } + + public static CockroachContainer getCRDB() { + return new CockroachContainer("cockroachdb/cockroach:latest"); + } + + private static JdbcDatabaseContainer containerSupplier() { + var container = USE_COCKROACH_DB ? getCRDB() : getPG(); + container.start(); + return container; + } + + static JdbcDatabaseContainer acquire() { + var container = POOL.poll(); + if (container != null) { + var jdbcUrl = container.getJdbcUrl().replaceFirst("/[^/]+$", "/" + DB_NAME); + try (var conn = + DriverManager.getConnection(jdbcUrl, container.getUsername(), container.getPassword())) { + truncateDbosTables(conn); + } catch (SQLException e) { + throw new RuntimeException(e); } return container; - } catch (InterruptedException e) { - throw new RuntimeException(e); } + container = containerSupplier(); + var jdbcUrl = container.getJdbcUrl().replaceFirst("/[^/]+$", "/" + DB_NAME); + + MigrationManager.runMigrations( + jdbcUrl, container.getUsername(), container.getPassword(), "dbos", true); + return container; + + } - static void release(PostgreSQLContainer c) { + static void release(JdbcDatabaseContainer c) { POOL.offer(c); - PERMITS.release(); } - private final PostgreSQLContainer pgContainer; + public static void truncateDbosTables(Connection conn) throws SQLException { + // truncate the DBOS tables from the test DB before returning to the pool + var truncate = + """ + TRUNCATE TABLE + "dbos".workflow_status, + "dbos".operation_outputs, + "dbos".workflow_events, + "dbos".workflow_events_history, + "dbos".notifications, + "dbos".event_dispatch_kv, + "dbos".streams, + "dbos".application_versions, + "dbos".workflow_schedules + CASCADE + """; + try (var stmt = conn.createStatement()) { + stmt.execute(truncate); + } + } + + private final JdbcDatabaseContainer pgContainer; private final String jdbcUrl; - private final String dbName; + private final boolean pooled; public PgContainer() { - // take a container from the pool and create a new database for it - pgContainer = acquire(); - dbName = "test_" + UUID.randomUUID().toString().replace("-", ""); - jdbcUrl = pgContainer.getJdbcUrl().replaceFirst("/[^/]+$", "/" + dbName); + this(false); + } + + private PgContainer(boolean requireFresh) { + pooled = !requireFresh; + pgContainer = pooled ? acquire() : containerSupplier(); + jdbcUrl = pgContainer.getJdbcUrl().replaceFirst("/[^/]+$", "/" + DB_NAME); + } + + public static PgContainer createFresh() { + return new PgContainer(true); } @Override public void close() throws Exception { - // drop the database we created and return the container too the pool - var _jdbcUrl = pgContainer.getJdbcUrl(); - try (var conn = DriverManager.getConnection(_jdbcUrl, username(), password()); - var stmt = conn.createStatement()) { - var sql = "DROP DATABASE IF EXISTS %s WITH (FORCE)".formatted(dbName); - stmt.execute(sql); + if (pooled) { + release(pgContainer); + } else { + pgContainer.close(); } - release(pgContainer); } public String jdbcUrl() { diff --git a/transact/src/test/resources/junit-platform.properties b/transact/src/test/resources/junit-platform.properties index f67fe92f..c5f0b3bc 100644 --- a/transact/src/test/resources/junit-platform.properties +++ b/transact/src/test/resources/junit-platform.properties @@ -2,5 +2,7 @@ junit.jupiter.execution.parallel.enabled = true junit.jupiter.execution.parallel.mode.default = concurrent junit.jupiter.execution.parallel.mode.classes.default = concurrent junit.jupiter.execution.parallel.config.strategy = dynamic +#junit.jupiter.execution.parallel.config.strategy = fixed junit.jupiter.execution.parallel.config.dynamic.factor = 1.0 +junit.jupiter.execution.parallel.config.fixed.parallelism = 2 junit.jupiter.execution.timeout.default = 2 m From d93cb21417764100644d9f9c3370af21b9a8f87c Mon Sep 17 00:00:00 2001 From: Harry Pierson Date: Sun, 17 May 2026 21:43:56 -0700 Subject: [PATCH 19/27] spotless --- .../test/java/dev/dbos/transact/utils/PgContainer.java | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/transact/src/test/java/dev/dbos/transact/utils/PgContainer.java b/transact/src/test/java/dev/dbos/transact/utils/PgContainer.java index 8d10dd42..813fa623 100644 --- a/transact/src/test/java/dev/dbos/transact/utils/PgContainer.java +++ b/transact/src/test/java/dev/dbos/transact/utils/PgContainer.java @@ -54,14 +54,12 @@ static JdbcDatabaseContainer acquire() { } return container; } - container = containerSupplier(); + container = containerSupplier(); var jdbcUrl = container.getJdbcUrl().replaceFirst("/[^/]+$", "/" + DB_NAME); - MigrationManager.runMigrations( - jdbcUrl, container.getUsername(), container.getPassword(), "dbos", true); - return container; - - + MigrationManager.runMigrations( + jdbcUrl, container.getUsername(), container.getPassword(), "dbos", true); + return container; } static void release(JdbcDatabaseContainer c) { From e6bd3ea196f4b5ee0a0644c03a22e5f87de18220 Mon Sep 17 00:00:00 2001 From: Harry Pierson Date: Sun, 17 May 2026 22:09:57 -0700 Subject: [PATCH 20/27] CrdbParallelExecutionConfigurationStrategy --- .github/workflows/on_pr.yml | 2 + .github/workflows/on_push.yml | 2 + .github/workflows/test_crdb.yml | 47 +++++++++++++++++ gradle/libs.versions.toml | 1 + .../dev/dbos/transact/cli/PgContainer.java | 2 +- .../dev/dbos/transact/utils/PgContainer.java | 2 +- .../dev/dbos/transact/utils/PgContainer.java | 2 +- .../dev/dbos/transact/spring/PgContainer.java | 2 +- transact/build.gradle.kts | 1 + ...arallelExecutionConfigurationStrategy.java | 50 +++++++++++++++++++ .../dev/dbos/transact/utils/PgContainer.java | 3 -- .../test/resources/junit-platform.properties | 5 +- 12 files changed, 109 insertions(+), 10 deletions(-) create mode 100644 .github/workflows/test_crdb.yml create mode 100644 transact/src/test/java/dev/dbos/transact/utils/CrdbParallelExecutionConfigurationStrategy.java diff --git a/.github/workflows/on_pr.yml b/.github/workflows/on_pr.yml index 714de1f6..d3ffff5f 100644 --- a/.github/workflows/on_pr.yml +++ b/.github/workflows/on_pr.yml @@ -14,5 +14,7 @@ on: jobs: test: uses: ./.github/workflows/test.yml + test-crdb: + uses: ./.github/workflows/test_crdb.yml test-demo-apps: uses: ./.github/workflows/test_demo_apps.yml \ No newline at end of file diff --git a/.github/workflows/on_push.yml b/.github/workflows/on_push.yml index c6d155ca..668fa3c9 100644 --- a/.github/workflows/on_push.yml +++ b/.github/workflows/on_push.yml @@ -9,6 +9,8 @@ on: jobs: test: uses: ./.github/workflows/test.yml + test-crdb: + uses: ./.github/workflows/test_crdb.yml publish: needs: test uses: ./.github/workflows/publish.yml diff --git a/.github/workflows/test_crdb.yml b/.github/workflows/test_crdb.yml new file mode 100644 index 00000000..bb032b67 --- /dev/null +++ b/.github/workflows/test_crdb.yml @@ -0,0 +1,47 @@ +name: Test (CockroachDB) + +on: + workflow_call: + workflow_dispatch: + +jobs: + test: + runs-on: ubuntu-latest + + steps: + - name: Checkout code + uses: actions/checkout@v6 + with: + fetch-depth: 0 # fetch-depth 0 needed for version calculation + + - name: Set up JDK temurin 17 + uses: actions/setup-java@v5 + with: + java-version: '17' + distribution: 'temurin' + + - name: Setup Gradle + uses: gradle/actions/setup-gradle@v5 + + - name: Run tests + run: ./gradlew clean build + env: + PGPASSWORD: dbos + JDKVERSION: '17' + DBOS_TEST_USE_COCKROACH_DB: 'true' + + - name: Test Summary + uses: test-summary/action@v2 + with: + paths: "transact/build/test-results/test/TEST-*.xml" + show: "fail, skip" + if: always() + + - name: Upload test results + uses: actions/upload-artifact@v7 + if: always() + with: + name: test-results-crdb-temurin-17 + path: | + transact/build/reports/tests/ + transact/build/test-results/ diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index 154d52bc..95d4cf67 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -51,6 +51,7 @@ jspecify = { module = "org.jspecify:jspecify", version.ref = "jspecify" } junit-bom = { module = "org.junit:junit-bom", version.ref = "junit" } junit-jupiter = { module = "org.junit.jupiter:junit-jupiter" } junit-pioneer = { module = "org.junit-pioneer:junit-pioneer", version.ref = "junit-pioneer" } +junit-platform-engine = { module = "org.junit.platform:junit-platform-engine" } junit-platform-launcher = { module = "org.junit.platform:junit-platform-launcher" } kryo = { module = "com.esotericsoftware:kryo", version.ref = "kryo" } logback-classic = { module = "ch.qos.logback:logback-classic", version.ref = "logback" } diff --git a/transact-cli/src/test/java/dev/dbos/transact/cli/PgContainer.java b/transact-cli/src/test/java/dev/dbos/transact/cli/PgContainer.java index 3db061be..0b881a8b 100644 --- a/transact-cli/src/test/java/dev/dbos/transact/cli/PgContainer.java +++ b/transact-cli/src/test/java/dev/dbos/transact/cli/PgContainer.java @@ -34,7 +34,7 @@ static PostgreSQLContainer acquire() { PERMITS.acquire(); var container = POOL.poll(); if (container == null) { - container = new PostgreSQLContainer("postgres:18"); + container = new PostgreSQLContainer("postgres:latest"); container.start(); } return container; diff --git a/transact-jdbi-step-factory/src/test/java/dev/dbos/transact/utils/PgContainer.java b/transact-jdbi-step-factory/src/test/java/dev/dbos/transact/utils/PgContainer.java index 2c1431dd..a466d222 100644 --- a/transact-jdbi-step-factory/src/test/java/dev/dbos/transact/utils/PgContainer.java +++ b/transact-jdbi-step-factory/src/test/java/dev/dbos/transact/utils/PgContainer.java @@ -38,7 +38,7 @@ static PostgreSQLContainer acquire() { PERMITS.acquire(); var container = POOL.poll(); if (container == null) { - container = new PostgreSQLContainer("postgres:18"); + container = new PostgreSQLContainer("postgres:latest"); container.start(); } return container; diff --git a/transact-jooq-step-factory/src/test/java/dev/dbos/transact/utils/PgContainer.java b/transact-jooq-step-factory/src/test/java/dev/dbos/transact/utils/PgContainer.java index 2c1431dd..a466d222 100644 --- a/transact-jooq-step-factory/src/test/java/dev/dbos/transact/utils/PgContainer.java +++ b/transact-jooq-step-factory/src/test/java/dev/dbos/transact/utils/PgContainer.java @@ -38,7 +38,7 @@ static PostgreSQLContainer acquire() { PERMITS.acquire(); var container = POOL.poll(); if (container == null) { - container = new PostgreSQLContainer("postgres:18"); + container = new PostgreSQLContainer("postgres:latest"); container.start(); } return container; diff --git a/transact-spring-boot-starter/src/test/java/dev/dbos/transact/spring/PgContainer.java b/transact-spring-boot-starter/src/test/java/dev/dbos/transact/spring/PgContainer.java index c8a2fa18..c64c739a 100644 --- a/transact-spring-boot-starter/src/test/java/dev/dbos/transact/spring/PgContainer.java +++ b/transact-spring-boot-starter/src/test/java/dev/dbos/transact/spring/PgContainer.java @@ -32,7 +32,7 @@ private static PostgreSQLContainer acquire() { PERMITS.acquire(); var container = POOL.poll(); if (container == null) { - container = new PostgreSQLContainer("postgres:18"); + container = new PostgreSQLContainer("postgres:latest"); container.start(); } return container; diff --git a/transact/build.gradle.kts b/transact/build.gradle.kts index 4eb6ac7e..c5016293 100644 --- a/transact/build.gradle.kts +++ b/transact/build.gradle.kts @@ -39,6 +39,7 @@ dependencies { testImplementation(libs.junit.jupiter) testImplementation(libs.junit.pioneer) testImplementation(libs.system.stubs.jupiter) + testImplementation(libs.junit.platform.engine) testRuntimeOnly(libs.junit.platform.launcher) testImplementation(libs.java.websocket) diff --git a/transact/src/test/java/dev/dbos/transact/utils/CrdbParallelExecutionConfigurationStrategy.java b/transact/src/test/java/dev/dbos/transact/utils/CrdbParallelExecutionConfigurationStrategy.java new file mode 100644 index 00000000..8c45d98d --- /dev/null +++ b/transact/src/test/java/dev/dbos/transact/utils/CrdbParallelExecutionConfigurationStrategy.java @@ -0,0 +1,50 @@ +package dev.dbos.transact.utils; + +import org.junit.platform.engine.ConfigurationParameters; +import org.junit.platform.engine.support.hierarchical.DefaultParallelExecutionConfigurationStrategy; +import org.junit.platform.engine.support.hierarchical.ParallelExecutionConfiguration; +import org.junit.platform.engine.support.hierarchical.ParallelExecutionConfigurationStrategy; + +public class CrdbParallelExecutionConfigurationStrategy + implements ParallelExecutionConfigurationStrategy { + + @Override + public ParallelExecutionConfiguration createConfiguration( + ConfigurationParameters configurationParameters) { + if (PgContainer.USE_COCKROACH_DB) { + int parallelism = Runtime.getRuntime().availableProcessors() >= 8 ? 2 : 1; + return fixedConfig(parallelism); + } + return DefaultParallelExecutionConfigurationStrategy.DYNAMIC.createConfiguration( + configurationParameters); + } + + private static ParallelExecutionConfiguration fixedConfig(int parallelism) { + return new ParallelExecutionConfiguration() { + @Override + public int getParallelism() { + return parallelism; + } + + @Override + public int getMinimumRunnable() { + return parallelism; + } + + @Override + public int getMaxPoolSize() { + return parallelism + 256; + } + + @Override + public int getCorePoolSize() { + return parallelism; + } + + @Override + public int getKeepAliveSeconds() { + return 30; + } + }; + } +} diff --git a/transact/src/test/java/dev/dbos/transact/utils/PgContainer.java b/transact/src/test/java/dev/dbos/transact/utils/PgContainer.java index 813fa623..54bbe876 100644 --- a/transact/src/test/java/dev/dbos/transact/utils/PgContainer.java +++ b/transact/src/test/java/dev/dbos/transact/utils/PgContainer.java @@ -17,9 +17,6 @@ import org.testcontainers.containers.JdbcDatabaseContainer; import org.testcontainers.postgresql.PostgreSQLContainer; -// TODO: custom junit.jupiter.execution.parallel.config.strategy / dynamic.factor / -// fixed.parallelism = 2 reader - public class PgContainer implements AutoCloseable { public static final boolean USE_COCKROACH_DB = diff --git a/transact/src/test/resources/junit-platform.properties b/transact/src/test/resources/junit-platform.properties index c5f0b3bc..bbef1c69 100644 --- a/transact/src/test/resources/junit-platform.properties +++ b/transact/src/test/resources/junit-platform.properties @@ -1,8 +1,7 @@ junit.jupiter.execution.parallel.enabled = true junit.jupiter.execution.parallel.mode.default = concurrent junit.jupiter.execution.parallel.mode.classes.default = concurrent -junit.jupiter.execution.parallel.config.strategy = dynamic -#junit.jupiter.execution.parallel.config.strategy = fixed +junit.jupiter.execution.parallel.config.strategy = custom +junit.jupiter.execution.parallel.config.custom.class = dev.dbos.transact.utils.CrdbParallelExecutionConfigurationStrategy junit.jupiter.execution.parallel.config.dynamic.factor = 1.0 -junit.jupiter.execution.parallel.config.fixed.parallelism = 2 junit.jupiter.execution.timeout.default = 2 m From f942974f7289c5b541061eaa4bcd46a7a3ad59c3 Mon Sep 17 00:00:00 2001 From: Harry Pierson Date: Sun, 17 May 2026 22:28:56 -0700 Subject: [PATCH 21/27] cleanup --- .github/workflows/test_demo_apps.yml | 2 +- .../src/main/java/dev/dbos/transact/cli/MigrateCommand.java | 2 +- .../java/dev/dbos/transact/invocation/CustomSchemaTest.java | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/test_demo_apps.yml b/.github/workflows/test_demo_apps.yml index a5a01c00..eae52a18 100644 --- a/.github/workflows/test_demo_apps.yml +++ b/.github/workflows/test_demo_apps.yml @@ -5,7 +5,7 @@ on: workflow_dispatch: jobs: - publish: + publish-maven-local: runs-on: ubuntu-latest outputs: version: ${{ steps.version.outputs.version }} diff --git a/transact-cli/src/main/java/dev/dbos/transact/cli/MigrateCommand.java b/transact-cli/src/main/java/dev/dbos/transact/cli/MigrateCommand.java index b5b90040..6a00da13 100644 --- a/transact-cli/src/main/java/dev/dbos/transact/cli/MigrateCommand.java +++ b/transact-cli/src/main/java/dev/dbos/transact/cli/MigrateCommand.java @@ -38,7 +38,7 @@ public Integer call() throws Exception { out.format(" System Database: %s\n", dbOptions.url()); out.format(" System Database User: %s\n", dbOptions.user()); - // TODO: real fix + // TODO: add option for useListenNotify MigrationManager.runMigrations( dbOptions.url(), dbOptions.user(), dbOptions.password(), dbOptions.schema(), true); grantDBOSSchemaPermissions(out, dbOptions.schema()); diff --git a/transact/src/test/java/dev/dbos/transact/invocation/CustomSchemaTest.java b/transact/src/test/java/dev/dbos/transact/invocation/CustomSchemaTest.java index 80f9d07a..482bafc7 100644 --- a/transact/src/test/java/dev/dbos/transact/invocation/CustomSchemaTest.java +++ b/transact/src/test/java/dev/dbos/transact/invocation/CustomSchemaTest.java @@ -21,7 +21,7 @@ import org.junit.jupiter.api.Test; public class CustomSchemaTest { - @AutoClose final PgContainer pgContainer = new PgContainer(); + @AutoClose final PgContainer pgContainer = PgContainer.createFresh(); private static final String schema = "F8nny_sCHem@-n@m3"; @AutoClose DBOS dbos; private HawkService proxy; From 97e0bb2d732549f6fff50af15ee8a054224c88c8 Mon Sep 17 00:00:00 2001 From: Harry Pierson Date: Sun, 17 May 2026 22:30:51 -0700 Subject: [PATCH 22/27] fix test_demo_apps --- .github/workflows/test_demo_apps.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test_demo_apps.yml b/.github/workflows/test_demo_apps.yml index eae52a18..7b3fa4ab 100644 --- a/.github/workflows/test_demo_apps.yml +++ b/.github/workflows/test_demo_apps.yml @@ -41,7 +41,7 @@ jobs: path: ~/.m2/repository/dev/dbos test-demo-apps: - needs: publish + needs: publish-maven-local runs-on: ubuntu-latest env: From 2d6f1f18d1065c2a3081c33642d4c1e7bc01d947 Mon Sep 17 00:00:00 2001 From: Harry Pierson Date: Sun, 17 May 2026 22:38:17 -0700 Subject: [PATCH 23/27] revert test demo apps changes --- .github/workflows/test_demo_apps.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/test_demo_apps.yml b/.github/workflows/test_demo_apps.yml index 7b3fa4ab..a5a01c00 100644 --- a/.github/workflows/test_demo_apps.yml +++ b/.github/workflows/test_demo_apps.yml @@ -5,7 +5,7 @@ on: workflow_dispatch: jobs: - publish-maven-local: + publish: runs-on: ubuntu-latest outputs: version: ${{ steps.version.outputs.version }} @@ -41,7 +41,7 @@ jobs: path: ~/.m2/repository/dev/dbos test-demo-apps: - needs: publish-maven-local + needs: publish runs-on: ubuntu-latest env: From e0a5ab277b5061f78543075467a78af0d85c88ae Mon Sep 17 00:00:00 2001 From: Harry Pierson Date: Sun, 17 May 2026 22:47:19 -0700 Subject: [PATCH 24/27] use JDK 25 for crdb CI --- .github/workflows/test_crdb.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/test_crdb.yml b/.github/workflows/test_crdb.yml index bb032b67..27060012 100644 --- a/.github/workflows/test_crdb.yml +++ b/.github/workflows/test_crdb.yml @@ -14,10 +14,10 @@ jobs: with: fetch-depth: 0 # fetch-depth 0 needed for version calculation - - name: Set up JDK temurin 17 + - name: Set up JDK temurin 25 uses: actions/setup-java@v5 with: - java-version: '17' + java-version: '25' distribution: 'temurin' - name: Setup Gradle From 23adadad9a31233bc64a9ab58143072d778021a9 Mon Sep 17 00:00:00 2001 From: Harry Pierson Date: Sun, 17 May 2026 23:04:56 -0700 Subject: [PATCH 25/27] Assumptions for virtual thread pool --- .../dbos/transact/execution/DBOSExecutorTest.java | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/transact/src/test/java/dev/dbos/transact/execution/DBOSExecutorTest.java b/transact/src/test/java/dev/dbos/transact/execution/DBOSExecutorTest.java index 36f8f6f8..f5b93f30 100644 --- a/transact/src/test/java/dev/dbos/transact/execution/DBOSExecutorTest.java +++ b/transact/src/test/java/dev/dbos/transact/execution/DBOSExecutorTest.java @@ -16,12 +16,12 @@ import java.util.List; import com.zaxxer.hikari.HikariDataSource; +import org.junit.jupiter.api.Assumptions; import org.junit.jupiter.api.AutoClose; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.DisabledForJreRange; import org.junit.jupiter.api.condition.EnabledForJreRange; -import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.junit.jupiter.api.condition.JRE; import org.junitpioneer.jupiter.RetryingTest; import org.slf4j.Logger; @@ -69,8 +69,9 @@ public void virtualThreadPoolJava21() throws Exception { } @Test - @EnabledIfEnvironmentVariable(named = "JDKVERSION", matches = "21|25") - public void virtualThreadPoolJDK21And25() throws Exception { + public void virtualThreadPoolJDK21OrLater() throws Exception { + int jdk = Runtime.version().feature(); + Assumptions.assumeTrue(jdk >= 21, "Skipping: requires JDK 21 or later, got " + jdk); try (var dbos = new DBOS(dbosConfig)) { dbos.launch(); assertFalse(DBOSTestAccess.getDbosExecutor(dbos).usingThreadPoolExecutor()); @@ -87,8 +88,9 @@ public void threadPoolJava17() throws Exception { } @Test - @EnabledIfEnvironmentVariable(named = "JDKVERSION", matches = "17|17\\..*") - public void threadPoolJDK17() throws Exception { + public void threadPoolJDK20OrEarlier() throws Exception { + int jdk = Runtime.version().feature(); + Assumptions.assumeTrue(jdk < 21, "Skipping: requires JDK 20 or earlier, got " + jdk); try (var dbos = new DBOS(dbosConfig)) { dbos.launch(); assertTrue(DBOSTestAccess.getDbosExecutor(dbos).usingThreadPoolExecutor()); From 88d4dc6dad90acbe0feea2ce5dd93e76ba5167a1 Mon Sep 17 00:00:00 2001 From: Harry Pierson Date: Mon, 18 May 2026 08:40:05 -0700 Subject: [PATCH 26/27] copilot feedback --- .github/workflows/test.yml | 1 - .github/workflows/test_crdb.yml | 3 +- .../dev/dbos/transact/cli/PgContainer.java | 2 +- .../dev/dbos/transact/utils/PgContainer.java | 2 +- .../dev/dbos/transact/utils/PgContainer.java | 2 +- .../dev/dbos/transact/spring/PgContainer.java | 2 +- .../transact/database/SystemDatabase.java | 6 ++-- .../database/dao/NotificationsDAO.java | 6 ++-- .../transact/database/signal/SignalMap.java | 7 +++- .../database/signal/Subscription.java | 2 ++ .../migrations/MigrationManagerTest.java | 34 +++++++++---------- .../dev/dbos/transact/utils/PgContainer.java | 4 +-- 12 files changed, 38 insertions(+), 33 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 99f1e764..601464c8 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -39,7 +39,6 @@ jobs: run: ./gradlew clean build env: PGPASSWORD: dbos - JDKVERSION: ${{ matrix.jdk-version }} SCALE_TEST: "true" - name: Test Summary diff --git a/.github/workflows/test_crdb.yml b/.github/workflows/test_crdb.yml index 27060012..42aa46f0 100644 --- a/.github/workflows/test_crdb.yml +++ b/.github/workflows/test_crdb.yml @@ -27,7 +27,6 @@ jobs: run: ./gradlew clean build env: PGPASSWORD: dbos - JDKVERSION: '17' DBOS_TEST_USE_COCKROACH_DB: 'true' - name: Test Summary @@ -41,7 +40,7 @@ jobs: uses: actions/upload-artifact@v7 if: always() with: - name: test-results-crdb-temurin-17 + name: test-results-crdb-temurin-25 path: | transact/build/reports/tests/ transact/build/test-results/ diff --git a/transact-cli/src/test/java/dev/dbos/transact/cli/PgContainer.java b/transact-cli/src/test/java/dev/dbos/transact/cli/PgContainer.java index 0b881a8b..3db061be 100644 --- a/transact-cli/src/test/java/dev/dbos/transact/cli/PgContainer.java +++ b/transact-cli/src/test/java/dev/dbos/transact/cli/PgContainer.java @@ -34,7 +34,7 @@ static PostgreSQLContainer acquire() { PERMITS.acquire(); var container = POOL.poll(); if (container == null) { - container = new PostgreSQLContainer("postgres:latest"); + container = new PostgreSQLContainer("postgres:18"); container.start(); } return container; diff --git a/transact-jdbi-step-factory/src/test/java/dev/dbos/transact/utils/PgContainer.java b/transact-jdbi-step-factory/src/test/java/dev/dbos/transact/utils/PgContainer.java index a466d222..2c1431dd 100644 --- a/transact-jdbi-step-factory/src/test/java/dev/dbos/transact/utils/PgContainer.java +++ b/transact-jdbi-step-factory/src/test/java/dev/dbos/transact/utils/PgContainer.java @@ -38,7 +38,7 @@ static PostgreSQLContainer acquire() { PERMITS.acquire(); var container = POOL.poll(); if (container == null) { - container = new PostgreSQLContainer("postgres:latest"); + container = new PostgreSQLContainer("postgres:18"); container.start(); } return container; diff --git a/transact-jooq-step-factory/src/test/java/dev/dbos/transact/utils/PgContainer.java b/transact-jooq-step-factory/src/test/java/dev/dbos/transact/utils/PgContainer.java index a466d222..2c1431dd 100644 --- a/transact-jooq-step-factory/src/test/java/dev/dbos/transact/utils/PgContainer.java +++ b/transact-jooq-step-factory/src/test/java/dev/dbos/transact/utils/PgContainer.java @@ -38,7 +38,7 @@ static PostgreSQLContainer acquire() { PERMITS.acquire(); var container = POOL.poll(); if (container == null) { - container = new PostgreSQLContainer("postgres:latest"); + container = new PostgreSQLContainer("postgres:18"); container.start(); } return container; diff --git a/transact-spring-boot-starter/src/test/java/dev/dbos/transact/spring/PgContainer.java b/transact-spring-boot-starter/src/test/java/dev/dbos/transact/spring/PgContainer.java index c64c739a..c8a2fa18 100644 --- a/transact-spring-boot-starter/src/test/java/dev/dbos/transact/spring/PgContainer.java +++ b/transact-spring-boot-starter/src/test/java/dev/dbos/transact/spring/PgContainer.java @@ -32,7 +32,7 @@ private static PostgreSQLContainer acquire() { PERMITS.acquire(); var container = POOL.poll(); if (container == null) { - container = new PostgreSQLContainer("postgres:latest"); + container = new PostgreSQLContainer("postgres:18"); container.start(); } return container; diff --git a/transact/src/main/java/dev/dbos/transact/database/SystemDatabase.java b/transact/src/main/java/dev/dbos/transact/database/SystemDatabase.java index d7a8490e..dfa0483b 100644 --- a/transact/src/main/java/dev/dbos/transact/database/SystemDatabase.java +++ b/transact/src/main/java/dev/dbos/transact/database/SystemDatabase.java @@ -47,13 +47,13 @@ public class SystemDatabase implements AutoCloseable { - public interface NotifcationRegistry { + public interface NotificationRegistry { Subscription subscribe(SignalKey.Message key); Subscription subscribe(SignalKey.Event key); } - public interface NotificationSource extends NotifcationRegistry { + public interface NotificationSource extends NotificationRegistry { void start(); void close(); @@ -121,7 +121,7 @@ private SystemDatabase( try { useListenNotify = isCockroach(dataSource) ? false : useListenNotify; } catch (SQLException e) { - logger.error("Failed to determine if dataSouce is CockroachDB", e); + logger.error("Failed to determine if dataSource is CockroachDB", e); useListenNotify = false; } diff --git a/transact/src/main/java/dev/dbos/transact/database/dao/NotificationsDAO.java b/transact/src/main/java/dev/dbos/transact/database/dao/NotificationsDAO.java index a4be1053..ee47d00b 100644 --- a/transact/src/main/java/dev/dbos/transact/database/dao/NotificationsDAO.java +++ b/transact/src/main/java/dev/dbos/transact/database/dao/NotificationsDAO.java @@ -3,7 +3,7 @@ import dev.dbos.transact.Constants; import dev.dbos.transact.database.DbContext; import dev.dbos.transact.database.GetEventCaller; -import dev.dbos.transact.database.SystemDatabase.NotifcationRegistry; +import dev.dbos.transact.database.SystemDatabase.NotificationRegistry; import dev.dbos.transact.database.signal.SignalKey; import dev.dbos.transact.database.signal.SignalMap; import dev.dbos.transact.exceptions.DBOSNonExistentWorkflowException; @@ -163,7 +163,7 @@ public static Object recv( int timeoutStepId, String topic, Duration dbPollingInterval, - NotifcationRegistry notifcationRegistry) + NotificationRegistry notifcationRegistry) throws SQLException { if (Objects.requireNonNull(workflowId).isEmpty()) { @@ -421,7 +421,7 @@ public static Object getEvent( Duration timeout, @Nullable GetEventCaller caller, Duration dbPollingInterval, - NotifcationRegistry notifcationRegistry) + NotificationRegistry notifcationRegistry) throws SQLException { if (Objects.requireNonNull(workflowId).isEmpty()) { diff --git a/transact/src/main/java/dev/dbos/transact/database/signal/SignalMap.java b/transact/src/main/java/dev/dbos/transact/database/signal/SignalMap.java index 10cfb5ed..4f3565a2 100644 --- a/transact/src/main/java/dev/dbos/transact/database/signal/SignalMap.java +++ b/transact/src/main/java/dev/dbos/transact/database/signal/SignalMap.java @@ -41,7 +41,12 @@ public Subscription subscribe(K key, WakeReason reason) { () -> map.compute(key, (k, e) -> e != null && e.refs.decrementAndGet() == 0 ? null : e)); - entry.future.thenAccept(sub::complete); + entry.future.thenAccept( + r -> { + if (!sub.closed) { + sub.complete(r); + } + }); return sub; } diff --git a/transact/src/main/java/dev/dbos/transact/database/signal/Subscription.java b/transact/src/main/java/dev/dbos/transact/database/signal/Subscription.java index 5cedc6da..8d7788fb 100644 --- a/transact/src/main/java/dev/dbos/transact/database/signal/Subscription.java +++ b/transact/src/main/java/dev/dbos/transact/database/signal/Subscription.java @@ -6,6 +6,7 @@ public class Subscription extends CompletableFuture implements AutoCloseable { private final Runnable onClose; + volatile boolean closed = false; public Subscription(Runnable onClose) { this.onClose = onClose; @@ -13,6 +14,7 @@ public Subscription(Runnable onClose) { @Override public void close() { + closed = true; onClose.run(); } } diff --git a/transact/src/test/java/dev/dbos/transact/migrations/MigrationManagerTest.java b/transact/src/test/java/dev/dbos/transact/migrations/MigrationManagerTest.java index 254a5883..7f9bd918 100644 --- a/transact/src/test/java/dev/dbos/transact/migrations/MigrationManagerTest.java +++ b/transact/src/test/java/dev/dbos/transact/migrations/MigrationManagerTest.java @@ -166,29 +166,20 @@ void testRunMigrations_CreatesDatabaseIfNotExists() throws Exception { // Verify the database does not exist before running migrations try (var conn = - DriverManager.getConnection( - pair.url(), pgContainer.username(), pgContainer.password()); - var stmt = conn.prepareStatement("SELECT 1 FROM pg_database WHERE datname = ?")) { - stmt.setString(1, pair.database()); - try (var rs = stmt.executeQuery()) { - assertFalse( - rs.next(), - "Database '%s' should not exist before runMigrations".formatted(pair.database())); - } + DriverManager.getConnection(pair.url(), pgContainer.username(), pgContainer.password())) { + assertFalse( + databaseExists(conn, pair.database()), + "Database '%s' should not exist before runMigrations".formatted(pair.database())); } MigrationManager.runMigrations(dbosConfig); // Verify the database now exists after running migrations try (var conn = - DriverManager.getConnection( - pair.url(), pgContainer.username(), pgContainer.password()); - var stmt = conn.prepareStatement("SELECT 1 FROM pg_database WHERE datname = ?")) { - stmt.setString(1, pair.database()); - try (var rs = stmt.executeQuery()) { - assertTrue( - rs.next(), "Database '%s' should exist after runMigrations".formatted(pair.database())); - } + DriverManager.getConnection(pair.url(), pgContainer.username(), pgContainer.password())) { + assertTrue( + databaseExists(conn, pair.database()), + "Database '%s' should exist after runMigrations".formatted(pair.database())); } } @@ -370,6 +361,15 @@ static void assertTriggerExists(Connection conn, String triggerName, String sche } } + static boolean databaseExists(Connection conn, String dbName) throws Exception { + try (ResultSet rs = conn.getMetaData().getCatalogs()) { + while (rs.next()) { + if (dbName.equals(rs.getString("TABLE_CAT"))) return true; + } + return false; + } + } + static void assertTriggerAbsent(Connection conn, String triggerName) throws Exception { String sql = "SELECT 1 FROM pg_trigger t JOIN pg_class c ON t.tgrelid = c.oid" diff --git a/transact/src/test/java/dev/dbos/transact/utils/PgContainer.java b/transact/src/test/java/dev/dbos/transact/utils/PgContainer.java index 54bbe876..6c029791 100644 --- a/transact/src/test/java/dev/dbos/transact/utils/PgContainer.java +++ b/transact/src/test/java/dev/dbos/transact/utils/PgContainer.java @@ -26,11 +26,11 @@ public class PgContainer implements AutoCloseable { private static final Queue> POOL = new ConcurrentLinkedQueue<>(); public static PostgreSQLContainer getPG() { - return new PostgreSQLContainer("postgres:latest"); + return new PostgreSQLContainer("postgres:18"); } public static CockroachContainer getCRDB() { - return new CockroachContainer("cockroachdb/cockroach:latest"); + return new CockroachContainer("cockroachdb/cockroach:latest-v26.2"); } private static JdbcDatabaseContainer containerSupplier() { From 38913e82fe69707fbf4b013560358ce1a03ff126 Mon Sep 17 00:00:00 2001 From: Harry Pierson Date: Mon, 18 May 2026 08:46:24 -0700 Subject: [PATCH 27/27] fix timeout bug in getEvent + tests --- .../database/dao/NotificationsDAO.java | 3 +- .../transact/notifications/EventsTest.java | 42 +++++++++++++++++++ 2 files changed, 44 insertions(+), 1 deletion(-) diff --git a/transact/src/main/java/dev/dbos/transact/database/dao/NotificationsDAO.java b/transact/src/main/java/dev/dbos/transact/database/dao/NotificationsDAO.java index ee47d00b..e8ef29ac 100644 --- a/transact/src/main/java/dev/dbos/transact/database/dao/NotificationsDAO.java +++ b/transact/src/main/java/dev/dbos/transact/database/dao/NotificationsDAO.java @@ -462,7 +462,8 @@ public static Object getEvent( : timeout.minus(Duration.between(startTime, Instant.now())); if (sleepDuration.isNegative() || sleepDuration.isZero()) { - result = new GetEventResult(null, null); + var serialized = SerializationUtil.serializeValue(null, null, ctx.serializer()); + result = new GetEventResult(serialized.serializedValue(), serialized.serialization()); break; } diff --git a/transact/src/test/java/dev/dbos/transact/notifications/EventsTest.java b/transact/src/test/java/dev/dbos/transact/notifications/EventsTest.java index b03ec185..b3d25d5d 100644 --- a/transact/src/test/java/dev/dbos/transact/notifications/EventsTest.java +++ b/transact/src/test/java/dev/dbos/transact/notifications/EventsTest.java @@ -51,6 +51,8 @@ interface EventsService { void setMultipleEventsWorkflow(); Map getAllEventsWorkflow(String workflowId); + + Object recvWorkflow(String topic, Duration timeout); } class EventsServiceImpl implements EventsService { @@ -165,6 +167,12 @@ public Map getAllEventsWorkflow(String workflowId) { return dbos.getAllEvents(workflowId); } + @Workflow + @Override + public Object recvWorkflow(String topic, Duration timeout) { + return dbos.recv(topic, timeout).orElse(null); + } + public void resetLatches() { advanceGetLatch1 = new CountDownLatch(1); advanceGetLatch2 = new CountDownLatch(1); @@ -426,6 +434,40 @@ public void getAllEventsAppearsInSteps() throws Exception { } } + @Test + public void testGetEventTimeoutReplayable() throws Exception { + var wfid = UUID.randomUUID().toString(); + + // Run workflow — getEvent times out and records a step result with null output + try (var ctx = new WorkflowOptions(wfid).setContext()) { + assertNull(proxy.getEventWorkflow("nonexistent-wfid", "somekey", Duration.ofMillis(50))); + } + + // Simulate crash: keep step result, reset workflow to PENDING + DBUtils.setWorkflowState(dataSource, wfid, WorkflowState.PENDING.name()); + + // Recover — workflow body re-executes, getEvent step replays via toResult() + var handle = DBOSTestAccess.getDbosExecutor(dbos).executeWorkflowById(wfid, true, false); + assertNull(handle.getResult()); + } + + @Test + public void testRecvTimeoutReplayable() throws Exception { + var wfid = UUID.randomUUID().toString(); + + // Run workflow — recv times out and records a step result with null output + try (var ctx = new WorkflowOptions(wfid).setContext()) { + assertNull(proxy.recvWorkflow("some-topic", Duration.ofMillis(50))); + } + + // Simulate crash: keep step result, reset workflow to PENDING + DBUtils.setWorkflowState(dataSource, wfid, WorkflowState.PENDING.name()); + + // Recover — workflow body re-executes, recv step replays via toResult() + var handle = DBOSTestAccess.getDbosExecutor(dbos).executeWorkflowById(wfid, true, false); + assertNull(handle.getResult()); + } + @Test public void concurrency() throws Exception { ExecutorService executor = Executors.newFixedThreadPool(2);