Skip to content

Commit 72f4959

Browse files
authored
Merge pull request #250 from jmarkerink/feat/switch_expression_operator
feat: implemented switch expression operator
2 parents 40e0ee8 + 1f403b2 commit 72f4959

2 files changed

Lines changed: 244 additions & 0 deletions

File tree

core/src/main/java/de/bwaldvogel/mongo/backend/aggregation/Expression.java

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1424,6 +1424,85 @@ Object apply(List<?> expressionValue, Document document) {
14241424
}
14251425
},
14261426

1427+
$switch {
1428+
@Override
1429+
Object apply(Object expressionValue, Document document) {
1430+
Document switchDocument = requireDocument(expressionValue, 40060);
1431+
1432+
// Validate that 'branches' field exists
1433+
if (!switchDocument.containsKey("branches")) {
1434+
throw new MongoServerError(40068, name() + " requires at least one branch");
1435+
}
1436+
1437+
// Validate unsupported parameters
1438+
List<String> supportedKeys = asList("branches", "default");
1439+
for (String key : switchDocument.keySet()) {
1440+
if (!supportedKeys.contains(key)) {
1441+
throw new MongoServerError(40067, name() + " found an unknown argument: " + key);
1442+
}
1443+
}
1444+
1445+
// Get and validate branches
1446+
Object branchesValue = switchDocument.get("branches");
1447+
if (!(branchesValue instanceof Collection<?>)) {
1448+
throw new MongoServerError(40061, name() + " expected an array for 'branches', found: " + describeType(branchesValue));
1449+
}
1450+
1451+
Collection<?> branches = (Collection<?>) branchesValue;
1452+
if (branches.isEmpty()) {
1453+
throw new MongoServerError(40060, name() + " requires at least one branch");
1454+
}
1455+
1456+
// Evaluate each branch
1457+
for (Object branchValue : branches) {
1458+
if (!(branchValue instanceof Document)) {
1459+
throw new MongoServerError(40062, name() + " expected each branch to be an object, found: " + describeType(branchValue));
1460+
}
1461+
1462+
Document branch = (Document) branchValue;
1463+
1464+
// Validate branch has required fields
1465+
if (!branch.containsKey("case")) {
1466+
throw new MongoServerError(40064, name() + " requires each branch have a 'case' expression");
1467+
}
1468+
if (!branch.containsKey("then")) {
1469+
throw new MongoServerError(40065, name() + " requires each branch have a 'then' expression");
1470+
}
1471+
1472+
// Validate branch has no extra fields
1473+
for (String key : branch.keySet()) {
1474+
if (!asList("case", "then").contains(key)) {
1475+
throw new MongoServerError(40063, name() + " found an unknown argument to a branch: " + key);
1476+
}
1477+
}
1478+
1479+
// Evaluate the case expression
1480+
Object caseExpression = branch.get("case");
1481+
Object caseResult = evaluate(caseExpression, document);
1482+
1483+
// If case is true, evaluate and return the then expression
1484+
if (Utils.isTrue(caseResult)) {
1485+
Object thenExpression = branch.get("then");
1486+
return evaluate(thenExpression, document);
1487+
}
1488+
}
1489+
1490+
// No case matched, check for default
1491+
if (switchDocument.containsKey("default")) {
1492+
Object defaultExpression = switchDocument.get("default");
1493+
return evaluate(defaultExpression, document);
1494+
}
1495+
1496+
// No case matched and no default provided
1497+
throw new MongoServerError(40066, name() + " could not find a matching branch for an input, and no default was specified.");
1498+
}
1499+
1500+
@Override
1501+
Object apply(List<?> expressionValue, Document document) {
1502+
throw new UnsupportedOperationException("must not be invoked");
1503+
}
1504+
},
1505+
14271506
$sqrt {
14281507
@Override
14291508
Object apply(List<?> expressionValue, Document document) {

test-common/src/main/java/de/bwaldvogel/mongo/backend/AbstractAggregationTest.java

Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2733,6 +2733,171 @@ void testProjectWithCondition() throws Exception {
27332733
);
27342734
}
27352735

2736+
@Test
2737+
void testAggregateWithSwitch() throws Exception {
2738+
collection.insertOne(json("_id: 1, name: 'Dave', qty: 1"));
2739+
collection.insertOne(json("_id: 2, name: 'Carol', qty: 5"));
2740+
collection.insertOne(json("_id: 3, name: 'Bob', qty: 10"));
2741+
collection.insertOne(json("_id: 4, name: 'Alice', qty: 20"));
2742+
2743+
List<Document> pipeline = jsonList("""
2744+
$project: {
2745+
name: 1,
2746+
qtyDiscount: {
2747+
$switch: {
2748+
branches: [
2749+
{ case: { $gte: ['$qty', 10] }, then: 0.15 },
2750+
{ case: { $gte: ['$qty', 5] }, then: 0.10 },
2751+
{ case: { $gte: ['$qty', 1] }, then: 0.05 }
2752+
],
2753+
default: 0
2754+
}
2755+
}
2756+
}
2757+
""");
2758+
2759+
assertThat(collection.aggregate(pipeline))
2760+
.containsExactlyInAnyOrder(
2761+
json("_id: 1, name: 'Dave', qtyDiscount: 0.05"),
2762+
json("_id: 2, name: 'Carol', qtyDiscount: 0.10"),
2763+
json("_id: 3, name: 'Bob', qtyDiscount: 0.15"),
2764+
json("_id: 4, name: 'Alice', qtyDiscount: 0.15")
2765+
);
2766+
}
2767+
2768+
@Test
2769+
void testAggregateWithSwitchDefault() throws Exception {
2770+
collection.insertOne(json("_id: 1, status: 'active'"));
2771+
collection.insertOne(json("_id: 2, status: 'inactive'"));
2772+
collection.insertOne(json("_id: 3, status: 'unknown'"));
2773+
2774+
List<Document> pipeline = jsonList("""
2775+
$project: {
2776+
statusCode: {
2777+
$switch: {
2778+
branches: [
2779+
{ case: { $eq: ['$status', 'active'] }, then: 1 },
2780+
{ case: { $eq: ['$status', 'inactive'] }, then: 0 }
2781+
],
2782+
default: -1
2783+
}
2784+
}
2785+
}
2786+
""");
2787+
2788+
assertThat(collection.aggregate(pipeline))
2789+
.containsExactlyInAnyOrder(
2790+
json("_id: 1, statusCode: 1"),
2791+
json("_id: 2, statusCode: 0"),
2792+
json("_id: 3, statusCode: -1")
2793+
);
2794+
}
2795+
2796+
@Test
2797+
void testAggregateWithSwitchMissingDefault() throws Exception {
2798+
collection.insertOne(json("_id: 1, value: 100"));
2799+
2800+
List<Document> pipeline = jsonList("""
2801+
$project: {
2802+
result: {
2803+
$switch: {
2804+
branches: [
2805+
{ case: { $eq: ['$value', 50] }, then: 'fifty' }
2806+
]
2807+
}
2808+
}
2809+
}
2810+
""");
2811+
2812+
assertThatExceptionOfType(MongoCommandException.class)
2813+
.isThrownBy(() -> collection.aggregate(pipeline).first())
2814+
.withMessageContaining("$switch could not find a matching branch for an input, and no default was specified");
2815+
}
2816+
2817+
@Test
2818+
void testAggregateWithSwitchMissingBranches() throws Exception {
2819+
collection.insertOne(json("_id: 1, value: 100"));
2820+
2821+
List<Document> pipeline = jsonList("""
2822+
$project: {
2823+
result: {
2824+
$switch: {
2825+
default: 'none'
2826+
}
2827+
}
2828+
}
2829+
""");
2830+
2831+
assertThatExceptionOfType(MongoCommandException.class)
2832+
.isThrownBy(() -> collection.aggregate(pipeline).first())
2833+
.withMessageContaining("$switch requires at least one branch");
2834+
}
2835+
2836+
@Test
2837+
void testAggregateWithSwitchEmptyBranches() throws Exception {
2838+
collection.insertOne(json("_id: 1, value: 100"));
2839+
2840+
List<Document> pipeline = jsonList("""
2841+
$project: {
2842+
result: {
2843+
$switch: {
2844+
branches: [
2845+
],
2846+
default: 'none'
2847+
}
2848+
}
2849+
}
2850+
""");
2851+
2852+
assertThatExceptionOfType(MongoCommandException.class)
2853+
.isThrownBy(() -> collection.aggregate(pipeline).first())
2854+
.withMessageContaining("$switch requires at least one branch");
2855+
}
2856+
2857+
@Test
2858+
void testAggregateWithSwitchInvalidBranch() throws Exception {
2859+
collection.insertOne(json("_id: 1, value: 100"));
2860+
2861+
List<Document> pipeline = jsonList("""
2862+
$project: {
2863+
result: {
2864+
$switch: {
2865+
branches: [
2866+
{ case: { $eq: ['$value', 100] } }
2867+
],
2868+
default: 'none'
2869+
}
2870+
}
2871+
}
2872+
""");
2873+
2874+
assertThatExceptionOfType(MongoCommandException.class)
2875+
.isThrownBy(() -> collection.aggregate(pipeline).first())
2876+
.withMessageContaining("$switch requires each branch have a 'then' expression");
2877+
}
2878+
2879+
@Test
2880+
void testAggregateWithSwitchInvalidArgument() throws Exception {
2881+
collection.insertOne(json("_id: 1, value: 100"));
2882+
2883+
List<Document> pipeline = jsonList("""
2884+
$project: {
2885+
result: {
2886+
$switch: {
2887+
branches: [
2888+
{ case: { $eq: ['$value', 100] }, then: 'one hundred' }
2889+
],
2890+
default_value: 'none'
2891+
}
2892+
}
2893+
}
2894+
""");
2895+
2896+
assertThatExceptionOfType(MongoCommandException.class)
2897+
.isThrownBy(() -> collection.aggregate(pipeline).first())
2898+
.withMessageContaining("$switch found an unknown argument: default_value");
2899+
}
2900+
27362901
// https://github.com/bwaldvogel/mongo-java-server/issues/138
27372902
@Test
27382903
public void testAggregateWithGeoNear() throws Exception {

0 commit comments

Comments
 (0)