diff --git a/infra/rewrite/core/src/main/java/org/apache/shardingsphere/infra/rewrite/engine/RouteSQLRewriteEngine.java b/infra/rewrite/core/src/main/java/org/apache/shardingsphere/infra/rewrite/engine/RouteSQLRewriteEngine.java index 9862e75336137..dec0363f21938 100644 --- a/infra/rewrite/core/src/main/java/org/apache/shardingsphere/infra/rewrite/engine/RouteSQLRewriteEngine.java +++ b/infra/rewrite/core/src/main/java/org/apache/shardingsphere/infra/rewrite/engine/RouteSQLRewriteEngine.java @@ -20,9 +20,9 @@ import lombok.RequiredArgsConstructor; import org.apache.shardingsphere.database.connector.core.type.DatabaseType; import org.apache.shardingsphere.infra.annotation.HighFrequencyInvocation; -import org.apache.shardingsphere.infra.config.props.ConfigurationPropertyKey; import org.apache.shardingsphere.infra.binder.context.statement.SQLStatementContext; import org.apache.shardingsphere.infra.binder.context.statement.type.dml.SelectStatementContext; +import org.apache.shardingsphere.infra.config.props.ConfigurationPropertyKey; import org.apache.shardingsphere.infra.datanode.DataNode; import org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase; import org.apache.shardingsphere.infra.metadata.database.resource.unit.StorageUnit; @@ -139,7 +139,7 @@ private SQLRewriteUnit createSQLRewriteUnit(final SQLRewriteContext sqlRewriteCo boolean containsDollarMarker = sqlRewriteContext.getSqlStatementContext() instanceof SelectStatementContext && ((SelectStatementContext) (sqlRewriteContext.getSqlStatementContext())).isContainsDollarParameterMarker(); for (RouteUnit each : routeUnits) { - sql.add(SQLUtils.trimSemicolon(new SQLBuilderEngine(sqlRewriteContext.getSql(), sqlRewriteContext.getSqlTokens(), each).buildSQL())); + sql.add(SQLUtils.trimSemicolon(new SQLBuilderEngine(sqlRewriteContext, each).buildSQL())); if (containsDollarMarker && !params.isEmpty()) { continue; } @@ -153,7 +153,7 @@ private SQLRewriteUnit createSQLRewriteUnit(final SQLRewriteContext sqlRewriteCo } private String getActualSQL(final SQLRewriteContext sqlRewriteContext, final RouteUnit routeUnit) { - return new SQLBuilderEngine(sqlRewriteContext.getSql(), sqlRewriteContext.getSqlTokens(), routeUnit).buildSQL(); + return new SQLBuilderEngine(sqlRewriteContext, routeUnit).buildSQL(); } private List getParameters(final SQLRewriteContext sqlRewriteContext, final RouteContext routeContext, final RouteUnit routeUnit) { diff --git a/infra/rewrite/core/src/main/java/org/apache/shardingsphere/infra/rewrite/sql/SQLBuilderEngine.java b/infra/rewrite/core/src/main/java/org/apache/shardingsphere/infra/rewrite/sql/SQLBuilderEngine.java index 502362c634b12..939e9b924987f 100644 --- a/infra/rewrite/core/src/main/java/org/apache/shardingsphere/infra/rewrite/sql/SQLBuilderEngine.java +++ b/infra/rewrite/core/src/main/java/org/apache/shardingsphere/infra/rewrite/sql/SQLBuilderEngine.java @@ -17,6 +17,7 @@ package org.apache.shardingsphere.infra.rewrite.sql; +import org.apache.shardingsphere.infra.rewrite.context.SQLRewriteContext; import org.apache.shardingsphere.infra.rewrite.sql.impl.AbstractSQLBuilder; import org.apache.shardingsphere.infra.rewrite.sql.impl.DefaultSQLBuilder; import org.apache.shardingsphere.infra.rewrite.sql.impl.RouteSQLBuilder; @@ -36,8 +37,8 @@ public SQLBuilderEngine(final String sql, final List sqlTokens) { sqlBuilder = new DefaultSQLBuilder(sql, sqlTokens); } - public SQLBuilderEngine(final String sql, final List sqlTokens, final RouteUnit routeUnit) { - sqlBuilder = new RouteSQLBuilder(sql, sqlTokens, routeUnit); + public SQLBuilderEngine(final SQLRewriteContext sqlRewriteContext, final RouteUnit routeUnit) { + sqlBuilder = new RouteSQLBuilder(sqlRewriteContext.getSql(), sqlRewriteContext.getSqlTokens(), routeUnit); } /** diff --git a/infra/rewrite/core/src/test/java/org/apache/shardingsphere/infra/rewrite/sql/SQLBuilderEngineTest.java b/infra/rewrite/core/src/test/java/org/apache/shardingsphere/infra/rewrite/sql/SQLBuilderEngineTest.java index 18041946a7a7f..cc4c85931dd27 100644 --- a/infra/rewrite/core/src/test/java/org/apache/shardingsphere/infra/rewrite/sql/SQLBuilderEngineTest.java +++ b/infra/rewrite/core/src/test/java/org/apache/shardingsphere/infra/rewrite/sql/SQLBuilderEngineTest.java @@ -17,6 +17,7 @@ package org.apache.shardingsphere.infra.rewrite.sql; +import org.apache.shardingsphere.infra.rewrite.context.SQLRewriteContext; import org.apache.shardingsphere.infra.rewrite.sql.fixture.SQLTokenFixture; import org.apache.shardingsphere.infra.route.context.RouteMapper; import org.apache.shardingsphere.infra.route.context.RouteUnit; @@ -27,6 +28,7 @@ import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.MatcherAssert.assertThat; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; class SQLBuilderEngineTest { @@ -45,14 +47,20 @@ void assertCreateSQLBuilderEngineWithDefaultConstructorAndTokens() { @Test void assertCreateSQLBuilderEngineWithRouteUnitConstructor() { RouteUnit routeUnit = new RouteUnit(mock(RouteMapper.class), Collections.singletonList(new RouteMapper("tbl", "tbl_0"))); - SQLBuilderEngine actual = new SQLBuilderEngine("SELECT * FROM tbl WHERE id=?", Collections.emptyList(), routeUnit); + SQLRewriteContext sqlRewriteContext = mock(SQLRewriteContext.class); + when(sqlRewriteContext.getSql()).thenReturn("SELECT * FROM tbl WHERE id=?"); + when(sqlRewriteContext.getSqlTokens()).thenReturn(Collections.emptyList()); + SQLBuilderEngine actual = new SQLBuilderEngine(sqlRewriteContext, routeUnit); assertThat(actual.buildSQL(), is("SELECT * FROM tbl WHERE id=?")); } @Test void assertCreateSQLBuilderEngineWithRouteUnitConstructorAndTokens() { RouteUnit routeUnit = new RouteUnit(mock(RouteMapper.class), Collections.singletonList(new RouteMapper("tbl", "tbl_0"))); - SQLBuilderEngine actual = new SQLBuilderEngine("SELECT * FROM tbl WHERE id=?", Collections.singletonList(new SQLTokenFixture(14, 16)), routeUnit); + SQLRewriteContext sqlRewriteContext = mock(SQLRewriteContext.class); + when(sqlRewriteContext.getSql()).thenReturn("SELECT * FROM tbl WHERE id=?"); + when(sqlRewriteContext.getSqlTokens()).thenReturn(Collections.singletonList(new SQLTokenFixture(14, 16))); + SQLBuilderEngine actual = new SQLBuilderEngine(sqlRewriteContext, routeUnit); assertThat(actual.buildSQL(), is("SELECT * FROM XXX WHERE id=?")); }