Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package org.testcontainers.junit.jupiter;

import lombok.Getter;
import lombok.Synchronized;
import org.junit.jupiter.api.TestInstance;
import org.junit.jupiter.api.extension.AfterAllCallback;
import org.junit.jupiter.api.extension.AfterEachCallback;
import org.junit.jupiter.api.extension.BeforeAllCallback;
Expand All @@ -12,6 +14,7 @@
import org.junit.jupiter.api.extension.ExtensionContext.Namespace;
import org.junit.jupiter.api.extension.ExtensionContext.Store;
import org.junit.jupiter.api.extension.ExtensionContext.Store.CloseableResource;
import org.junit.jupiter.api.extension.TestInstancePostProcessor;
import org.junit.platform.commons.support.AnnotationSupport;
import org.junit.platform.commons.support.HierarchyTraversalMode;
import org.junit.platform.commons.support.ModifierSupport;
Expand All @@ -33,7 +36,13 @@
import java.util.stream.Stream;

public class TestcontainersExtension
implements BeforeEachCallback, BeforeAllCallback, AfterEachCallback, AfterAllCallback, ExecutionCondition {
implements
BeforeEachCallback,
BeforeAllCallback,
AfterEachCallback,
AfterAllCallback,
ExecutionCondition,
TestInstancePostProcessor {

private static final Namespace NAMESPACE = Namespace.create(TestcontainersExtension.class);

Expand All @@ -43,8 +52,23 @@ public class TestcontainersExtension

private final DockerAvailableDetector dockerDetector = new DockerAvailableDetector();

@Override
public void postProcessTestInstance(Object testInstance, ExtensionContext context) {
TestInstance.Lifecycle lifecycle = context.getTestInstanceLifecycle().orElse(null);
if (lifecycle == TestInstance.Lifecycle.PER_CLASS) {
beforeAllImpl(context);
}
}

@Override
public void beforeAll(ExtensionContext context) {
TestInstance.Lifecycle lifecycle = context.getTestInstanceLifecycle().orElse(null);
if (lifecycle != TestInstance.Lifecycle.PER_CLASS) {
beforeAllImpl(context);
}
}

private void beforeAllImpl(ExtensionContext context) {
Class<?> testClass = context
.getTestClass()
.orElseThrow(() -> {
Expand All @@ -71,16 +95,14 @@ private void startContainers(List<StoreAdapter> storeAdapters, Store store, Exte
return;
}

List<StoreAdapter> storedAdapters = storeAdapters
.stream()
.map(adapter -> (StoreAdapter) store.getOrComputeIfAbsent(adapter.getKey(), k -> adapter))
.collect(Collectors.toList());
if (isParallelExecutionEnabled(context)) {
Stream<Startable> startables = storeAdapters
.stream()
.map(storeAdapter -> {
store.getOrComputeIfAbsent(storeAdapter.getKey(), k -> storeAdapter);
return storeAdapter.container;
});
Startables.deepStart(startables).join();
Startables.deepStart(storedAdapters).join();
} else {
storeAdapters.forEach(adapter -> store.getOrComputeIfAbsent(adapter.getKey(), k -> adapter.start()));
storedAdapters.forEach(StoreAdapter::start);
}
}

Expand Down Expand Up @@ -260,7 +282,7 @@ private static StoreAdapter getContainerInstance(final Object testInstance, fina
* thereby letting the JUnit automatically stop containers once the current
* {@link ExtensionContext} is closed.
*/
private static class StoreAdapter implements CloseableResource, AutoCloseable {
private static class StoreAdapter implements Startable, CloseableResource, AutoCloseable {

@Getter
private String key;
Expand All @@ -272,14 +294,29 @@ private StoreAdapter(Class<?> declaringClass, String fieldName, Startable contai
this.container = container;
}

private StoreAdapter start() {
container.start();
return this;
private boolean started;

@Override
@Synchronized
public void start() {
if (!started) {
container.start();
started = true;
}
}

@Override
@Synchronized
public void stop() {
if (started) {
container.stop();
started = false;
}
}

@Override
public void close() {
container.stop();
stop();
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
package org.testcontainers.junit.jupiter;

import org.junit.jupiter.api.MethodOrderer;
import org.junit.jupiter.api.Nested;
import org.junit.jupiter.api.Order;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInstance;
import org.junit.jupiter.api.TestMethodOrder;
import org.testcontainers.lifecycle.Startable;

import static org.assertj.core.api.Assertions.assertThat;

/**
* Verifies that instance {@link Container @Container} fields are started exactly once
* per test instance for both {@link TestInstance.Lifecycle#PER_CLASS} and
* {@link TestInstance.Lifecycle#PER_METHOD} lifecycles.
*/
@Testcontainers
class TestcontainersInstanceLifecycleTest {

@Container
private static final StartCountingMock staticContainer = new StartCountingMock();

@Nested
@TestInstance(TestInstance.Lifecycle.PER_CLASS)
@TestMethodOrder(MethodOrderer.OrderAnnotation.class)
class PerClass {

@Container
private final StartCountingMock instanceContainer = new StartCountingMock();

@Test
@Order(1)
void first_test() {
assertThat(staticContainer.starts).isEqualTo(1);
assertThat(instanceContainer.starts).isEqualTo(1);
}

@Test
@Order(2)
void second_test() {
assertThat(staticContainer.starts).as("Static container should be started exactly once").isEqualTo(1);
assertThat(instanceContainer.starts)
.as("PER_CLASS instance container should be started for every test")
.isEqualTo(2);
}
}

@Nested
@TestMethodOrder(MethodOrderer.OrderAnnotation.class)
class PerMethod {

@Container
private final StartCountingMock instanceContainer = new StartCountingMock();

@Test
@Order(1)
void first_test() {
assertThat(staticContainer.starts).isEqualTo(1);
assertThat(instanceContainer.starts).isEqualTo(1);
}

@Test
@Order(2)
void second_test() {
assertThat(staticContainer.starts).as("Static container should be started exactly once").isEqualTo(1);
assertThat(instanceContainer.starts).isEqualTo(1);
}
}

static class StartCountingMock implements Startable {

int starts;

@Override
public void start() {
starts++;
}

@Override
public void stop() {}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
package org.testcontainers.junit.jupiter;

import org.junit.jupiter.api.TestInstance;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.testcontainers.lifecycle.Startable;

import java.util.UUID;
import java.util.stream.Stream;

import static org.assertj.core.api.Assertions.assertThat;

/**
* Verifies that static {@link Container @Container} fields are available in non-static
* {@link MethodSource @MethodSource} factory methods with
* {@link TestInstance.Lifecycle#PER_CLASS PER_CLASS} lifecycle.
*/
@TestInstance(TestInstance.Lifecycle.PER_CLASS)
@Testcontainers
class TestcontainersPerClassPostProcessorTest {

@Container
private static final StartTrackingMock staticContainer = new StartTrackingMock();

@Container
private final StartTrackingMock instanceContainer = new StartTrackingMock();

private boolean staticStartedDuringMethodSource;

private boolean instanceStartedDuringMethodSource;

Stream<String> arguments() {
staticStartedDuringMethodSource = staticContainer.containerId != null;
instanceStartedDuringMethodSource = instanceContainer.containerId != null;
return Stream.of("a");
}

@ParameterizedTest
@MethodSource("arguments")
void containers_are_started_before_method_source(String argument) {
assertThat(staticStartedDuringMethodSource)
.as("Static container should be started before @MethodSource resolution")
.isTrue();
assertThat(instanceStartedDuringMethodSource)
.as("Instance container should NOT be started before @MethodSource resolution")
.isFalse();
}

static class StartTrackingMock implements Startable {

String containerId;

@Override
public void start() {
containerId = UUID.randomUUID().toString();
}

@Override
public void stop() {}
}
}