diff --git a/example/simple-table-demo/src/main/java/com/oceanbase/example/ObDirectLoadDemo.java b/example/simple-table-demo/src/main/java/com/oceanbase/example/ObDirectLoadDemo.java index 25b7ac3c..fd6c9610 100644 --- a/example/simple-table-demo/src/main/java/com/oceanbase/example/ObDirectLoadDemo.java +++ b/example/simple-table-demo/src/main/java/com/oceanbase/example/ObDirectLoadDemo.java @@ -53,6 +53,9 @@ public static void main(String[] args) { SimpleTest.run(); ParallelWriteTest.run(); MultiNodeWriteTest.run(); + P2PModeWriteTest.run(); + SimpleAbortTest.run(); + P2PModeAbortTest.run(); } private static void prepareTestTable() throws Exception { @@ -105,16 +108,16 @@ private static ObDirectLoadConnection buildConnection(int writeThreadNum) .enableParallelWrite(writeThreadNum).build(); } - private static ObDirectLoadStatement buildStatement(ObDirectLoadConnection connection) + private static ObDirectLoadStatement buildStatement(ObDirectLoadConnection connection, boolean isP2PMode) throws ObDirectLoadException { return connection.getStatementBuilder().setTableName(tableName).setDupAction(dupAction) - .setParallel(parallel).setQueryTimeout(timeout).build(); + .setParallel(parallel).setQueryTimeout(timeout).setIsP2PMode(isP2PMode).build(); } - private static ObDirectLoadStatement buildStatement(ObDirectLoadConnection connection, ObDirectLoadStatementExecutionId executionId) + private static ObDirectLoadStatement buildStatement(ObDirectLoadConnection connection, ObDirectLoadStatementExecutionId executionId, boolean isP2PMode) throws ObDirectLoadException { return connection.getStatementBuilder().setTableName(tableName).setDupAction(dupAction) - .setParallel(parallel).setQueryTimeout(timeout).setExecutionId(executionId).build(); + .setParallel(parallel).setQueryTimeout(timeout).setExecutionId(executionId).setIsP2PMode(isP2PMode).build(); } private static class SimpleTest { @@ -127,7 +130,7 @@ public static void run() { prepareTestTable(); connection = buildConnection(1); - statement = buildStatement(connection); + statement = buildStatement(connection, false); statement.begin(); @@ -192,7 +195,7 @@ public static void run() { prepareTestTable(); connection = buildConnection(parallel); - statement = buildStatement(connection); + statement = buildStatement(connection, false); statement.begin(); @@ -246,7 +249,7 @@ public void run() { executionId.decode(executionIdBytes); connection = buildConnection(1); - statement = buildStatement(connection, executionId); + statement = buildStatement(connection, executionId, false); ObDirectLoadBucket bucket = new ObDirectLoadBucket(); ObObj[] rowObjs = new ObObj[2]; @@ -277,7 +280,7 @@ public static void run() { prepareTestTable(); connection = buildConnection(1); - statement = buildStatement(connection); + statement = buildStatement(connection, false); statement.begin(); @@ -313,4 +316,228 @@ public static void run() { }; + private static class P2PModeWriteTest { + + private static class P2PNodeWriter implements Runnable { + + private final byte[] executionIdBytes; + private final int id; + private final AtomicInteger ref_cnt; + + P2PNodeWriter(byte[] executionIdBytes, int id, AtomicInteger ref_cnt) { + this.executionIdBytes = executionIdBytes; + this.id = id; + this.ref_cnt = ref_cnt; + } + + @Override + public void run() { + ObDirectLoadConnection connection = null; + ObDirectLoadStatement statement = null; + try { + ObDirectLoadStatementExecutionId executionId = new ObDirectLoadStatementExecutionId(); + executionId.decode(executionIdBytes); + + connection = buildConnection(1); + statement = buildStatement(connection, executionId, true); + + ObDirectLoadBucket bucket = new ObDirectLoadBucket(); + ObObj[] rowObjs = new ObObj[2]; + rowObjs[0] = new ObObj(ObObjType.ObInt32Type.getDefaultObjMeta(), id); + rowObjs[1] = new ObObj(ObObjType.ObInt32Type.getDefaultObjMeta(), id); + bucket.addRow(rowObjs); + statement.write(bucket); + + if (0 == ref_cnt.decrementAndGet()) { + statement.commit(); + } + + } catch (Exception e) { + throw new RuntimeException(e); + } finally { + if (null != statement) { + statement.close(); + } + if (null != connection) { + connection.close(); + } + } + } + + }; + + public static void run() { + System.out.println("P2PModeWriteTest start"); + final int writeThreadNum = 10; + ObDirectLoadConnection connection = null; + ObDirectLoadStatement statement = null; + final AtomicInteger ref_cnt = new AtomicInteger(writeThreadNum); + try { + prepareTestTable(); + + connection = buildConnection(1); + statement = buildStatement(connection, true); + + statement.begin(); + + ObDirectLoadStatementExecutionId executionId = statement.getExecutionId(); + byte[] executionIdBytes = executionId.encode(); + + Thread[] threads = new Thread[writeThreadNum]; + for (int i = 0; i < threads.length; ++i) { + P2PNodeWriter NodeWriter = new P2PNodeWriter(executionIdBytes, i, ref_cnt); + Thread thread = new Thread(NodeWriter); + thread.start(); + threads[i] = thread; + } + for (int i = 0; i < threads.length; ++i) { + threads[i].join(); + } + queryTestTable(writeThreadNum); + } catch (Exception e) { + throw new RuntimeException(e); + } finally { + if (null != statement) { + statement.close(); + } + if (null != connection) { + connection.close(); + } + } + System.out.println("P2PModeWriteTest successful"); + } + + }; + + private static class SimpleAbortTest { + + public static void run() { + System.out.println("SimpleAbortTest start"); + ObDirectLoadConnection connection = null; + ObDirectLoadStatement statement = null; + try { + prepareTestTable(); + System.out.println("prepareTestTable"); + + connection = buildConnection(1); + statement = buildStatement(connection, false); + + statement.begin(); + + ObDirectLoadBucket bucket = new ObDirectLoadBucket(); + ObObj[] rowObjs = new ObObj[2]; + rowObjs[0] = new ObObj(ObObjType.ObInt32Type.getDefaultObjMeta(), 1); + rowObjs[1] = new ObObj(ObObjType.ObInt32Type.getDefaultObjMeta(), 2); + bucket.addRow(rowObjs); + statement.write(bucket); + + statement.abort(); + + queryTestTable(0); + } catch (Exception e) { + throw new RuntimeException(e); + } finally { + if (null != statement) { + statement.close(); + } + if (null != connection) { + connection.close(); + } + } + System.out.println("SimpleAbortTest successful"); + } + + }; + + private static class P2PModeAbortTest { + + + private static class AbortP2PNode implements Runnable { + + private final byte[] executionIdBytes; + private final int id; + + AbortP2PNode(byte[] executionIdBytes, int id) { + this.executionIdBytes = executionIdBytes; + this.id = id; + } + + @Override + public void run() { + ObDirectLoadConnection connection = null; + ObDirectLoadStatement statement = null; + try { + ObDirectLoadStatementExecutionId executionId = new ObDirectLoadStatementExecutionId(); + executionId.decode(executionIdBytes); + + connection = buildConnection(1); + statement = buildStatement(connection, executionId, true); + + ObDirectLoadBucket bucket = new ObDirectLoadBucket(); + ObObj[] rowObjs = new ObObj[2]; + rowObjs[0] = new ObObj(ObObjType.ObInt32Type.getDefaultObjMeta(), id); + rowObjs[1] = new ObObj(ObObjType.ObInt32Type.getDefaultObjMeta(), id); + bucket.addRow(rowObjs); + statement.write(bucket); + + statement.abort(); + + } catch (Exception e) { + throw new RuntimeException(e); + } finally { + if (null != statement) { + statement.close(); + } + if (null != connection) { + connection.close(); + } + } + } + + }; + + public static void run() { + System.out.println("P2PModeAbortTest start"); + ObDirectLoadConnection connection = null; + ObDirectLoadStatement statement = null; + try { + prepareTestTable(); + + connection = buildConnection(1); + statement = buildStatement(connection, true); + + statement.begin(); + + ObDirectLoadBucket bucket = new ObDirectLoadBucket(); + ObObj[] rowObjs = new ObObj[2]; + rowObjs[0] = new ObObj(ObObjType.ObInt32Type.getDefaultObjMeta(), 1); + rowObjs[1] = new ObObj(ObObjType.ObInt32Type.getDefaultObjMeta(), 2); + bucket.addRow(rowObjs); + statement.write(bucket); + + ObDirectLoadStatementExecutionId executionId = statement.getExecutionId(); + byte[] executionIdBytes = executionId.encode(); + + AbortP2PNode abortP2PNode = new AbortP2PNode(executionIdBytes, 3); + Thread abortNodeThread = new Thread(abortP2PNode); + abortNodeThread.start(); + abortNodeThread.join(); + + queryTestTable(0); + + } catch (Exception e) { + throw new RuntimeException(e); + } finally { + if (null != statement) { + statement.close(); + } + if (null != connection) { + connection.close(); + } + } + System.out.println("P2PModeAbortTest successful"); + } + + }; + } diff --git a/src/main/java/com/alipay/oceanbase/rpc/direct_load/ObDirectLoadStatement.java b/src/main/java/com/alipay/oceanbase/rpc/direct_load/ObDirectLoadStatement.java index 06a01666..f72ab6e0 100644 --- a/src/main/java/com/alipay/oceanbase/rpc/direct_load/ObDirectLoadStatement.java +++ b/src/main/java/com/alipay/oceanbase/rpc/direct_load/ObDirectLoadStatement.java @@ -87,7 +87,7 @@ public synchronized void init(Builder builder) throws ObDirectLoadException { connection.getProtocol().checkIsSupported(this); obTablePool = new ObDirectLoadConnection.ObTablePool(connection, logger, queryTimeout); obTablePool.init(); - executor = new ObDirectLoadStatementExecutor(this); + executor = new ObDirectLoadStatementExecutor(this, builder.isP2PMode); if (builder.executionId != null) { executor.resume(builder.executionId); } @@ -308,6 +308,24 @@ public void resume(ObDirectLoadStatementExecutionId executionId) throws ObDirect executor.resume(executionId); } + public ObDirectLoadStatementFuture abortAsync() { + try { + checkStatus(); + return executor.requestAbort(); + } catch (ObDirectLoadException e) { + logger.warn("statement abort failed", e); + return new ObDirectLoadStatementFailedFuture(this, e); + } + } + + public void abort() throws ObDirectLoadException { + ObDirectLoadStatementFuture future = abortAsync(); + future.await(); + if (!future.isSuccess()) { + throw future.cause(); + } + } + public static final class Builder { private final ObDirectLoadConnection connection; @@ -325,6 +343,7 @@ public static final class Builder { private ObDirectLoadTraceId traceId = null; private ObDirectLoadStatementExecutionId executionId = null; + private boolean isP2PMode = false; private static final long MAX_QUERY_TIMEOUT = Integer.MAX_VALUE; @@ -382,6 +401,11 @@ public ObDirectLoadTraceId getTraceId() { return traceId; } + public Builder setIsP2PMode(boolean isP2PMode) { + this.isP2PMode = isP2PMode; + return this; + } + public String toString() { return String .format( diff --git a/src/main/java/com/alipay/oceanbase/rpc/direct_load/execution/ObDirectLoadStatementExecutor.java b/src/main/java/com/alipay/oceanbase/rpc/direct_load/execution/ObDirectLoadStatementExecutor.java index f9551849..4399d3c3 100644 --- a/src/main/java/com/alipay/oceanbase/rpc/direct_load/execution/ObDirectLoadStatementExecutor.java +++ b/src/main/java/com/alipay/oceanbase/rpc/direct_load/execution/ObDirectLoadStatementExecutor.java @@ -58,13 +58,21 @@ public class ObDirectLoadStatementExecutor { private long taskId = 0; private ObAddr svrAddr = null; private ObDirectLoadException cause = null; // 失败原因 + private NodeRole nodeRole = NodeRole.PRIMARY; private AtomicInteger writingCount = new AtomicInteger(0); - public ObDirectLoadStatementExecutor(ObDirectLoadStatement statement) { + public enum NodeRole { + PRIMARY, WRITE_ONLY, P2P; + } + + public ObDirectLoadStatementExecutor(ObDirectLoadStatement statement, boolean isP2PMode) { this.statement = statement; this.traceId = statement.getTraceId(); this.logger = statement.getLogger(); + if (isP2PMode) { + this.nodeRole = NodeRole.P2P; + } } public ObDirectLoadStatement getStatement() { @@ -103,6 +111,11 @@ public synchronized ObDirectLoadStatementFuture begin() { logger.info("statement call begin"); ObDirectLoadStatementAsyncPromiseTask task = null; try { + if (NodeRole.PRIMARY != nodeRole && NodeRole.P2P != nodeRole) { + logger.warn("unexpected node role during begin process", nodeRole); + throw new ObDirectLoadUnexpectedException( + "unexpected node role during begin process"); + } compareAndSetState(NONE, BEGINNING, "begin"); } catch (ObDirectLoadException e) { logger.warn("statement begin failed", e); @@ -116,6 +129,7 @@ public synchronized ObDirectLoadStatementFuture begin() { logger.warn("statement start begin failed", e); cause = e; tryCompareAndSetState(BEGINNING, FAIL, "set begin failure"); + return new ObDirectLoadStatementFailedFuture(statement, e); } return task; } @@ -124,6 +138,11 @@ public synchronized ObDirectLoadStatementFuture commit() { logger.info("statement call commit"); ObDirectLoadStatementAsyncPromiseTask task = null; try { + if (NodeRole.PRIMARY != nodeRole && NodeRole.P2P != nodeRole) { + logger.warn("unexpected node role during commit process", nodeRole); + throw new ObDirectLoadUnexpectedException( + "unexpected node role during commit process"); + } compareAndSetState(LOADING, COMMITTING, "commit"); } catch (ObDirectLoadException e) { logger.warn("statement commit failed", e); @@ -137,6 +156,7 @@ public synchronized ObDirectLoadStatementFuture commit() { logger.warn("statement start commit failed", e); cause = e; tryCompareAndSetState(COMMITTING, FAIL, "set commit failure"); + return new ObDirectLoadStatementFailedFuture(statement, e); } return task; } @@ -174,7 +194,13 @@ public synchronized void resume(ObDirectLoadStatementExecutionId executionId) throws ObDirectLoadException { logger.info("statement call resume"); try { - compareAndSetState(NONE, LOADING_ONLY, "resume"); + if (NodeRole.P2P == nodeRole) { + compareAndSetState(NONE, LOADING, "resume in P2P mode"); + startHeartBeat(); + } else { + nodeRole = NodeRole.WRITE_ONLY; + compareAndSetState(NONE, LOADING_ONLY, "resume"); + } } catch (ObDirectLoadException e) { logger.warn("statement resume failed", e); throw e; @@ -272,6 +298,10 @@ public void close() { private synchronized void abortIfNeed() { logger.debug("statement abort if need"); + if (NodeRole.PRIMARY != nodeRole) { + //other roles have no ownership + return; + } if (abortFuture != null) { logger.debug("statement in abort"); return; @@ -317,6 +347,33 @@ private synchronized void abortIfNeed() { } } + public synchronized ObDirectLoadStatementFuture requestAbort() { + ObDirectLoadStatementFuture task = null; + final int state = stateFlag.get(); + try { + if (NodeRole.PRIMARY != nodeRole && NodeRole.P2P != nodeRole) { + logger.warn("unexpected node role during abort process", nodeRole); + throw new ObDirectLoadUnexpectedException( + "unexpected node role during abort process"); + } else if (NONE == state || BEGINNING == state) { + String message = "abort is not allowed, because state is " + state; + logger.warn(message); + throw new ObDirectLoadIllegalStateException(message); + } else { + if (this.abortFuture != null) { + logger.debug("already in abort process"); + task = abortFuture; + } else { + task = abort(); + } + } + } catch (ObDirectLoadException e) { + logger.warn("statement request abort failed", e); + return new ObDirectLoadStatementFailedFuture(statement, e); + } + return task; + } + private ObDirectLoadStatementFuture abort() { logger.info("statement call abort"); setState(ABORT);