Skip to content

Commit 0a1d373

Browse files
committed
Adding SparkSQL to DatabaseType and adding SaprkSQLPaginationQueryProvider
1 parent 10f9b66 commit 0a1d373

4 files changed

Lines changed: 152 additions & 1 deletion

File tree

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
/*
2+
* Copyright 2006-2022 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.batch.infrastructure.item.database.support;
18+
19+
import org.springframework.batch.infrastructure.item.database.PagingQueryProvider;
20+
import org.springframework.util.StringUtils;
21+
22+
/**
23+
* SparkSQL implementation of a {@link PagingQueryProvider} using database specific features.
24+
*
25+
* @author Rahul Kumar
26+
* @since 2.0
27+
*/
28+
public class SparkSqlPagingQueryProvider extends AbstractSqlPagingQueryProvider {
29+
30+
@Override
31+
public String generateFirstPageQuery(int pageSize) {
32+
return SqlPagingQueryUtils.generateLimitSqlQuery(this, false, buildLimitClause(pageSize));
33+
}
34+
35+
@Override
36+
public String generateRemainingPagesQuery(int pageSize) {
37+
if (StringUtils.hasText(getGroupClause())) {
38+
return SqlPagingQueryUtils.generateLimitGroupedSqlQuery(this, buildLimitClause(pageSize));
39+
}
40+
else {
41+
return SqlPagingQueryUtils.generateLimitSqlQuery(this, true, buildLimitClause(pageSize));
42+
}
43+
}
44+
45+
private String buildLimitClause(int pageSize) {
46+
return "LIMIT " + pageSize;
47+
}
48+
49+
}

spring-batch-infrastructure/src/main/java/org/springframework/batch/infrastructure/item/database/support/SqlPagingQueryProviderFactoryBean.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
import static org.springframework.batch.infrastructure.support.DatabaseType.SQLITE;
3131
import static org.springframework.batch.infrastructure.support.DatabaseType.SQLSERVER;
3232
import static org.springframework.batch.infrastructure.support.DatabaseType.SYBASE;
33+
import static org.springframework.batch.infrastructure.support.DatabaseType.SPARKSQL;
3334

3435
import java.util.HashMap;
3536
import java.util.LinkedHashMap;
@@ -89,6 +90,7 @@ public class SqlPagingQueryProviderFactoryBean implements FactoryBean<PagingQuer
8990
providers.put(SQLITE, new SqlitePagingQueryProvider());
9091
providers.put(SQLSERVER, new SqlServerPagingQueryProvider());
9192
providers.put(SYBASE, new SybasePagingQueryProvider());
93+
providers.put(SPARKSQL, new SparkSqlPagingQueryProvider());
9294
}
9395

9496
/**

spring-batch-infrastructure/src/main/java/org/springframework/batch/infrastructure/support/DatabaseType.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ public enum DatabaseType {
4343

4444
DERBY("Apache Derby"), DB2("DB2"), DB2VSE("DB2VSE"), DB2ZOS("DB2ZOS"), DB2AS400("DB2AS400"),
4545
HSQL("HSQL Database Engine"), SQLSERVER("Microsoft SQL Server"), MYSQL("MySQL"), ORACLE("Oracle"),
46-
POSTGRES("PostgreSQL"), SYBASE("Sybase"), H2("H2"), SQLITE("SQLite"), HANA("HDB"), MARIADB("MariaDB");
46+
POSTGRES("PostgreSQL"), SYBASE("Sybase"), H2("H2"), SQLITE("SQLite"), HANA("HDB"), MARIADB("MariaDB"), SPARKSQL("SparkSQL");
4747

4848
private static final Map<String, DatabaseType> DATABASE_TYPES = Arrays.stream(DatabaseType.values())
4949
.collect(Collectors.toMap(DatabaseType::getProductName, Function.identity()));
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
/*
2+
* Copyright 2006-2022 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package org.springframework.batch.infrastructure.item.database.support;
17+
18+
import org.junit.jupiter.api.Test;
19+
import org.springframework.batch.infrastructure.item.database.Order;
20+
21+
import java.util.HashMap;
22+
import java.util.Map;
23+
24+
import static org.junit.jupiter.api.Assertions.assertEquals;
25+
26+
/**
27+
* @author Rahul Kumar
28+
*/
29+
class SparkSqlPagingQueryProviderTests extends AbstractSqlPagingQueryProviderTests {
30+
31+
SparkSqlPagingQueryProviderTests() {
32+
pagingQueryProvider = new SparkSqlPagingQueryProvider();
33+
}
34+
35+
@Test
36+
@Override
37+
void testGenerateFirstPageQuery() {
38+
String sql = "SELECT id, name, age FROM foo WHERE bar = 1 ORDER BY id ASC LIMIT 100";
39+
String s = pagingQueryProvider.generateFirstPageQuery(pageSize);
40+
assertEquals(sql, s);
41+
}
42+
43+
@Test
44+
@Override
45+
void testGenerateRemainingPagesQuery() {
46+
String sql = "SELECT id, name, age FROM foo WHERE (bar = 1) AND ((id > ?)) ORDER BY id ASC LIMIT 100";
47+
String s = pagingQueryProvider.generateRemainingPagesQuery(pageSize);
48+
assertEquals(sql, s);
49+
}
50+
51+
@Override
52+
@Test
53+
void testGenerateFirstPageQueryWithGroupBy() {
54+
pagingQueryProvider.setGroupClause("dep");
55+
String sql = "SELECT id, name, age FROM foo WHERE bar = 1 GROUP BY dep ORDER BY id ASC LIMIT 100";
56+
String s = pagingQueryProvider.generateFirstPageQuery(pageSize);
57+
assertEquals(sql, s);
58+
}
59+
60+
@Override
61+
@Test
62+
void testGenerateRemainingPagesQueryWithGroupBy() {
63+
pagingQueryProvider.setGroupClause("dep");
64+
String sql = "SELECT * FROM (SELECT id, name, age FROM foo WHERE bar = 1 GROUP BY dep) AS MAIN_QRY WHERE ((id > ?)) ORDER BY id ASC LIMIT 100";
65+
String s = pagingQueryProvider.generateRemainingPagesQuery(pageSize);
66+
assertEquals(sql, s);
67+
}
68+
69+
@Test
70+
void testFirstPageSqlWithAliases() {
71+
Map<String, Order> sorts = new HashMap<>();
72+
sorts.put("owner.id", Order.ASCENDING);
73+
74+
this.pagingQueryProvider = new SparkSqlPagingQueryProvider();
75+
this.pagingQueryProvider.setSelectClause("SELECT owner.id as ownerid, first_name, last_name, dog_name ");
76+
this.pagingQueryProvider.setFromClause("FROM dog_owner owner INNER JOIN dog ON owner.id = dog.id ");
77+
this.pagingQueryProvider.setSortKeys(sorts);
78+
79+
String firstPage = this.pagingQueryProvider.generateFirstPageQuery(5);
80+
String remainingPagesQuery = this.pagingQueryProvider.generateRemainingPagesQuery(5);
81+
82+
assertEquals(
83+
"SELECT owner.id as ownerid, first_name, last_name, dog_name FROM dog_owner owner INNER JOIN dog ON owner.id = dog.id ORDER BY owner.id ASC LIMIT 5",
84+
firstPage);
85+
assertEquals(
86+
"SELECT owner.id as ownerid, first_name, last_name, dog_name FROM dog_owner owner INNER JOIN dog ON owner.id = dog.id WHERE ((owner.id > ?)) ORDER BY owner.id ASC LIMIT 5",
87+
remainingPagesQuery);
88+
}
89+
90+
@Override
91+
String getFirstPageSqlWithMultipleSortKeys() {
92+
return "SELECT id, name, age FROM foo WHERE bar = 1 ORDER BY name ASC, id DESC LIMIT 100";
93+
}
94+
95+
@Override
96+
String getRemainingSqlWithMultipleSortKeys() {
97+
return "SELECT id, name, age FROM foo WHERE (bar = 1) AND ((name > ?) OR (name = ? AND id < ?)) ORDER BY name ASC, id DESC LIMIT 100";
98+
}
99+
100+
}

0 commit comments

Comments
 (0)