Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@
import lombok.RequiredArgsConstructor;
import org.apache.shardingsphere.database.connector.core.type.DatabaseType;
import org.apache.shardingsphere.infra.annotation.HighFrequencyInvocation;
import org.apache.shardingsphere.infra.config.props.ConfigurationPropertyKey;
import org.apache.shardingsphere.infra.binder.context.statement.SQLStatementContext;
import org.apache.shardingsphere.infra.binder.context.statement.type.dml.SelectStatementContext;
import org.apache.shardingsphere.infra.config.props.ConfigurationPropertyKey;
import org.apache.shardingsphere.infra.datanode.DataNode;
import org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase;
import org.apache.shardingsphere.infra.metadata.database.resource.unit.StorageUnit;
Expand Down Expand Up @@ -139,7 +139,7 @@ private SQLRewriteUnit createSQLRewriteUnit(final SQLRewriteContext sqlRewriteCo
boolean containsDollarMarker = sqlRewriteContext.getSqlStatementContext() instanceof SelectStatementContext
&& ((SelectStatementContext) (sqlRewriteContext.getSqlStatementContext())).isContainsDollarParameterMarker();
for (RouteUnit each : routeUnits) {
sql.add(SQLUtils.trimSemicolon(new SQLBuilderEngine(sqlRewriteContext.getSql(), sqlRewriteContext.getSqlTokens(), each).buildSQL()));
sql.add(SQLUtils.trimSemicolon(new SQLBuilderEngine(sqlRewriteContext, each).buildSQL()));
if (containsDollarMarker && !params.isEmpty()) {
continue;
}
Expand All @@ -153,7 +153,7 @@ private SQLRewriteUnit createSQLRewriteUnit(final SQLRewriteContext sqlRewriteCo
}

private String getActualSQL(final SQLRewriteContext sqlRewriteContext, final RouteUnit routeUnit) {
return new SQLBuilderEngine(sqlRewriteContext.getSql(), sqlRewriteContext.getSqlTokens(), routeUnit).buildSQL();
return new SQLBuilderEngine(sqlRewriteContext, routeUnit).buildSQL();
}

private List<Object> getParameters(final SQLRewriteContext sqlRewriteContext, final RouteContext routeContext, final RouteUnit routeUnit) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

package org.apache.shardingsphere.infra.rewrite.sql;

import org.apache.shardingsphere.infra.rewrite.context.SQLRewriteContext;
import org.apache.shardingsphere.infra.rewrite.sql.impl.AbstractSQLBuilder;
import org.apache.shardingsphere.infra.rewrite.sql.impl.DefaultSQLBuilder;
import org.apache.shardingsphere.infra.rewrite.sql.impl.RouteSQLBuilder;
Expand All @@ -36,8 +37,8 @@ public SQLBuilderEngine(final String sql, final List<SQLToken> sqlTokens) {
sqlBuilder = new DefaultSQLBuilder(sql, sqlTokens);
}

public SQLBuilderEngine(final String sql, final List<SQLToken> sqlTokens, final RouteUnit routeUnit) {
sqlBuilder = new RouteSQLBuilder(sql, sqlTokens, routeUnit);
public SQLBuilderEngine(final SQLRewriteContext sqlRewriteContext, final RouteUnit routeUnit) {
sqlBuilder = new RouteSQLBuilder(sqlRewriteContext.getSql(), sqlRewriteContext.getSqlTokens(), routeUnit);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

package org.apache.shardingsphere.infra.rewrite.sql;

import org.apache.shardingsphere.infra.rewrite.context.SQLRewriteContext;
import org.apache.shardingsphere.infra.rewrite.sql.fixture.SQLTokenFixture;
import org.apache.shardingsphere.infra.route.context.RouteMapper;
import org.apache.shardingsphere.infra.route.context.RouteUnit;
Expand All @@ -27,6 +28,7 @@
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

class SQLBuilderEngineTest {

Expand All @@ -45,14 +47,20 @@ void assertCreateSQLBuilderEngineWithDefaultConstructorAndTokens() {
@Test
void assertCreateSQLBuilderEngineWithRouteUnitConstructor() {
RouteUnit routeUnit = new RouteUnit(mock(RouteMapper.class), Collections.singletonList(new RouteMapper("tbl", "tbl_0")));
SQLBuilderEngine actual = new SQLBuilderEngine("SELECT * FROM tbl WHERE id=?", Collections.emptyList(), routeUnit);
SQLRewriteContext sqlRewriteContext = mock(SQLRewriteContext.class);
when(sqlRewriteContext.getSql()).thenReturn("SELECT * FROM tbl WHERE id=?");
when(sqlRewriteContext.getSqlTokens()).thenReturn(Collections.emptyList());
SQLBuilderEngine actual = new SQLBuilderEngine(sqlRewriteContext, routeUnit);
assertThat(actual.buildSQL(), is("SELECT * FROM tbl WHERE id=?"));
}

@Test
void assertCreateSQLBuilderEngineWithRouteUnitConstructorAndTokens() {
RouteUnit routeUnit = new RouteUnit(mock(RouteMapper.class), Collections.singletonList(new RouteMapper("tbl", "tbl_0")));
SQLBuilderEngine actual = new SQLBuilderEngine("SELECT * FROM tbl WHERE id=?", Collections.singletonList(new SQLTokenFixture(14, 16)), routeUnit);
SQLRewriteContext sqlRewriteContext = mock(SQLRewriteContext.class);
when(sqlRewriteContext.getSql()).thenReturn("SELECT * FROM tbl WHERE id=?");
when(sqlRewriteContext.getSqlTokens()).thenReturn(Collections.singletonList(new SQLTokenFixture(14, 16)));
SQLBuilderEngine actual = new SQLBuilderEngine(sqlRewriteContext, routeUnit);
assertThat(actual.buildSQL(), is("SELECT * FROM XXX WHERE id=?"));
}

Expand Down
Loading