Skip to content

Commit 8567271

Browse files
committed
Thread Safe JUnit Extensions
JUnit extensions are now fully thread safe through the use of `ExtensionContext.Store`. `SingleInstancePostgresExtension` has had its functionality extended to be able to run for the entire class instead of just per method.
1 parent 532a048 commit 8567271

9 files changed

Lines changed: 165 additions & 75 deletions

File tree

README.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ dependencies {
3434
And the following to your `gradle.properties`:
3535
```properties
3636
# Check this on https://central.sonatype.com/artifact/com.smushytaco/embedded-postgres/
37-
embedded_postgres_version = 3.0.1
37+
embedded_postgres_version = 3.0.2
3838
```
3939

4040
The default version of the embedded postgres is `PostgreSQL 18.0.0`, but you can change it by following the instructions described in [Postgres version](#postgres-version).
@@ -45,7 +45,7 @@ In your JUnit test just add:
4545

4646
```java
4747
@RegisterExtension
48-
SingleInstancePostgresExtension pg = EmbeddedPostgresExtension.singleInstance();
48+
final SingleInstancePostgresExtension pg = EmbeddedPostgresExtension.singleInstance();
4949
```
5050

5151
This simply has JUnit manage an instance of EmbeddedPostgres (start, stop). You can then use this to get a DataSource with: `pg.getEmbeddedPostgres().getPostgresDatabase();`
@@ -60,13 +60,13 @@ You can easily integrate Flyway or Liquibase database schema migration:
6060
##### Flyway
6161
```java
6262
@RegisterExtension
63-
PreparedDbExtension db = EmbeddedPostgresExtension.preparedDatabase(FlywayPreparer.forClasspathLocation("db/my-db-schema"));
63+
final PreparedDbExtension db = EmbeddedPostgresExtension.preparedDatabase(FlywayPreparer.forClasspathLocation("db/my-db-schema"));
6464
```
6565

6666
##### Liquibase
6767
```java
6868
@RegisterExtension
69-
PreparedDbExtension db = EmbeddedPostgresExtension.preparedDatabase(LiquibasePreparer.forClasspathLocation("liqui/master.xml"));
69+
final PreparedDbExtension db = EmbeddedPostgresExtension.preparedDatabase(LiquibasePreparer.forClasspathLocation("liqui/master.xml"));
7070
```
7171

7272
This will create an independent database for every test with the given schema loaded from the classpath.

gradle.properties

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ org.gradle.configuration-cache = true
77
# Maven Properties
88
group = com.smushytaco
99
name = embedded-postgres
10-
version = 3.0.1
10+
version = 3.0.2
1111
##########################################################################
1212
# Java Version
1313
java_version = 25

src/main/java/com/smushytaco/postgres/embedded/DefaultPostgresBinaryResolver.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
* (e.g., {@code postgres-Linux-x86_64.txz}) and provides them as an {@link InputStream}.
4242
* It also includes fallback behavior for unsupported or emulated architectures.
4343
*/
44+
@SuppressWarnings("java:S6548")
4445
public class DefaultPostgresBinaryResolver implements PgBinaryResolver {
4546
private static final Logger logger = LoggerFactory.getLogger(DefaultPostgresBinaryResolver.class);
4647

src/main/java/com/smushytaco/postgres/embedded/EmbeddedPostgres.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -383,6 +383,7 @@ private void pgCtl(final Path dir, final String action) {
383383
system(pgCtl, args, null);
384384
}
385385

386+
@SuppressWarnings("java:S1141")
386387
private void cleanOldDataDirectories(final Path parentDirectory) {
387388
try (final Stream<Path> children = Files.list(parentDirectory)) {
388389
for (final Path dir : children.toList()) {
@@ -844,6 +845,7 @@ private void closeChannel(final Channel channel) {
844845
}
845846
}
846847

848+
@SuppressWarnings("java:S1141")
847849
private static Path prepareBinaries(final PgBinaryResolver pgBinaryResolver, final Path overrideWorkingDirectory) {
848850
PREPARE_BINARIES_LOCK.lock();
849851
try {

src/main/java/com/smushytaco/postgres/embedded/PreparedDbProvider.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,7 @@ private static String randomAlphabetic(final int length) {
259259
.toLowerCase(Locale.ENGLISH);
260260
}
261261

262+
@SuppressWarnings("java:S2189")
262263
@Override
263264
public void run() {
264265
while (true) {
@@ -401,6 +402,7 @@ public Map<String, String> getProperties() {
401402
*
402403
* @return the exception that caused creation to fail, or {@code null} on success
403404
*/
405+
@SuppressWarnings("unused")
404406
public SQLException getException() {
405407
return ex;
406408
}

src/main/java/com/smushytaco/postgres/junit/PreparedDbExtension.java

Lines changed: 71 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import java.sql.SQLException;
2828
import java.util.List;
2929
import java.util.concurrent.CopyOnWriteArrayList;
30+
import java.util.concurrent.atomic.AtomicBoolean;
3031
import java.util.function.Consumer;
3132

3233
/**
@@ -37,14 +38,15 @@
3738
* per test class or each test can get its own fresh database instance.
3839
*/
3940
public class PreparedDbExtension implements BeforeAllCallback, AfterAllCallback, BeforeEachCallback, AfterEachCallback {
40-
private final DatabasePreparer preparer;
41-
private boolean perClass = false;
42-
private volatile DataSource dataSource;
43-
private volatile PreparedDbProvider provider;
44-
private volatile ConnectionInfo connectionInfo;
41+
private final Object stateKey = new Object();
4542

43+
private final DatabasePreparer preparer;
4644
private final List<Consumer<EmbeddedPostgres.Builder>> builderCustomizers = new CopyOnWriteArrayList<>();
4745

46+
private final AtomicBoolean started = new AtomicBoolean(false);
47+
48+
private final ThreadLocal<State> current = new ThreadLocal<>();
49+
4850
PreparedDbExtension(final DatabasePreparer preparer) {
4951
if (preparer == null) throw new IllegalStateException("null preparer");
5052
this.preparer = preparer;
@@ -62,41 +64,49 @@ public class PreparedDbExtension implements BeforeAllCallback, AfterAllCallback,
6264
* @throws AssertionError if the extension has already been started
6365
*/
6466
public PreparedDbExtension customize(final Consumer<EmbeddedPostgres.Builder> customizer) {
65-
if (dataSource != null) throw new AssertionError("already started");
67+
if (started.get()) throw new AssertionError("already started");
6668
builderCustomizers.add(customizer);
6769
return this;
6870
}
6971

72+
@SuppressWarnings("DuplicatedCode")
7073
@Override
71-
public void beforeAll(@NonNull final ExtensionContext extensionContext) throws SQLException {
72-
provider = PreparedDbProvider.forPreparer(preparer, builderCustomizers);
73-
connectionInfo = provider.createNewDatabase();
74-
dataSource = provider.createDataSourceFromConnectionInfo(connectionInfo);
75-
perClass = true;
74+
public void beforeAll(@NonNull final ExtensionContext ctx) throws SQLException {
75+
final ExtensionContext.Store classStore = classStore(ctx);
76+
State state = classStore.get(stateKey, State.class);
77+
if (state == null) {
78+
state = createState();
79+
classStore.put(stateKey, state);
80+
started.set(true);
81+
}
82+
current.set(state);
7683
}
7784

7885
@Override
79-
public void afterAll(@NonNull final ExtensionContext extensionContext) {
80-
dataSource = null;
81-
connectionInfo = null;
82-
provider = null;
83-
perClass = false;
86+
public void afterAll(@NonNull final ExtensionContext ctx) {
87+
current.remove();
8488
}
8589

90+
@SuppressWarnings("DuplicatedCode")
8691
@Override
87-
public void beforeEach(@NonNull final ExtensionContext extensionContext) throws SQLException {
88-
if (perClass) return;
89-
provider = PreparedDbProvider.forPreparer(preparer, builderCustomizers);
90-
connectionInfo = provider.createNewDatabase();
91-
dataSource = provider.createDataSourceFromConnectionInfo(connectionInfo);
92+
public void beforeEach(@NonNull final ExtensionContext ctx) throws SQLException {
93+
final ExtensionContext.Store classStore = classStore(ctx);
94+
State state = classStore.get(stateKey, State.class);
95+
if (state == null) {
96+
final ExtensionContext.Store methodStore = methodStore(ctx);
97+
state = methodStore.get(stateKey, State.class);
98+
if (state == null) {
99+
state = createState();
100+
methodStore.put(stateKey, state);
101+
started.set(true);
102+
}
103+
}
104+
current.set(state);
92105
}
93106

94107
@Override
95-
public void afterEach(@NonNull final ExtensionContext extensionContext) {
96-
if (perClass) return;
97-
dataSource = null;
98-
connectionInfo = null;
99-
provider = null;
108+
public void afterEach(@NonNull final ExtensionContext ctx) {
109+
current.remove();
100110
}
101111

102112
/**
@@ -108,8 +118,9 @@ public void afterEach(@NonNull final ExtensionContext extensionContext) {
108118
* @throws AssertionError if the extension has not been initialized yet
109119
*/
110120
public DataSource getTestDatabase() {
111-
if (dataSource == null) throw new AssertionError("not initialized");
112-
return dataSource;
121+
final State s = current.get();
122+
if (s == null) throw new AssertionError("not initialized");
123+
return s.dataSource;
113124
}
114125

115126
/**
@@ -120,8 +131,9 @@ public DataSource getTestDatabase() {
120131
* @throws AssertionError if the extension has not been initialized yet
121132
*/
122133
public ConnectionInfo getConnectionInfo() {
123-
if (connectionInfo == null) throw new AssertionError("not initialized");
124-
return connectionInfo;
134+
final State s = current.get();
135+
if (s == null) throw new AssertionError("not initialized");
136+
return s.connectionInfo;
125137
}
126138

127139
/**
@@ -132,7 +144,34 @@ public ConnectionInfo getConnectionInfo() {
132144
* @throws AssertionError if the extension has not been initialized yet
133145
*/
134146
public PreparedDbProvider getDbProvider() {
135-
if (provider == null) throw new AssertionError("not initialized");
136-
return provider;
147+
final State s = current.get();
148+
if (s == null) throw new AssertionError("not initialized");
149+
return s.provider;
150+
}
151+
152+
private State createState() throws SQLException {
153+
final PreparedDbProvider provider = PreparedDbProvider.forPreparer(preparer, builderCustomizers);
154+
final ConnectionInfo connectionInfo = provider.createNewDatabase();
155+
return new State(provider, connectionInfo, provider.createDataSourceFromConnectionInfo(connectionInfo));
156+
}
157+
158+
private static ExtensionContext.Store classStore(final ExtensionContext ctx) {
159+
return ctx.getStore(ExtensionContext.Namespace.create(PreparedDbExtension.class, ctx.getRequiredTestClass()));
160+
}
161+
162+
private static ExtensionContext.Store methodStore(final ExtensionContext ctx) {
163+
return ctx.getStore(ExtensionContext.Namespace.create(PreparedDbExtension.class, ctx.getRequiredTestMethod()));
164+
}
165+
166+
private record State(PreparedDbProvider provider, ConnectionInfo connectionInfo, DataSource dataSource) implements AutoCloseable {
167+
@Override
168+
public void close() {
169+
try {
170+
if (provider instanceof final AutoCloseable autoCloseable) autoCloseable.close();
171+
} catch (final Exception _) { /* noop */ }
172+
try {
173+
if (dataSource instanceof final AutoCloseable autoCloseable) autoCloseable.close();
174+
} catch (final Exception _) { /* noop */ }
175+
}
137176
}
138177
}

0 commit comments

Comments
 (0)