diff --git a/example/simple-table-demo/src/main/java/com/oceanbase/example/ObDirectLoadDemo.java b/example/simple-table-demo/src/main/java/com/oceanbase/example/ObDirectLoadDemo.java index 27af3893..25b7ac3c 100644 --- a/example/simple-table-demo/src/main/java/com/oceanbase/example/ObDirectLoadDemo.java +++ b/example/simple-table-demo/src/main/java/com/oceanbase/example/ObDirectLoadDemo.java @@ -111,6 +111,12 @@ private static ObDirectLoadStatement buildStatement(ObDirectLoadConnection conne .setParallel(parallel).setQueryTimeout(timeout).build(); } + private static ObDirectLoadStatement buildStatement(ObDirectLoadConnection connection, ObDirectLoadStatementExecutionId executionId) + throws ObDirectLoadException { + return connection.getStatementBuilder().setTableName(tableName).setDupAction(dupAction) + .setParallel(parallel).setQueryTimeout(timeout).setExecutionId(executionId).build(); + } + private static class SimpleTest { public static void run() { @@ -240,9 +246,7 @@ public void run() { executionId.decode(executionIdBytes); connection = buildConnection(1); - statement = buildStatement(connection); - - statement.resume(executionId); + statement = buildStatement(connection, executionId); ObDirectLoadBucket bucket = new ObDirectLoadBucket(); ObObj[] rowObjs = new ObObj[2]; diff --git a/src/main/java/com/alipay/oceanbase/rpc/direct_load/ObDirectLoadConnection.java b/src/main/java/com/alipay/oceanbase/rpc/direct_load/ObDirectLoadConnection.java index e4b8736b..38e927a1 100644 --- a/src/main/java/com/alipay/oceanbase/rpc/direct_load/ObDirectLoadConnection.java +++ b/src/main/java/com/alipay/oceanbase/rpc/direct_load/ObDirectLoadConnection.java @@ -231,7 +231,8 @@ public void executeWithConnection(final ObDirectLoadRpc rpc, ObTable table, long } } - public synchronized ObDirectLoadStatement createStatement() throws ObDirectLoadException { + public synchronized ObDirectLoadStatement createStatement(ObDirectLoadTraceId traceId) + throws ObDirectLoadException { if (!isInited) { logger.warn("connection not init"); throw new ObDirectLoadIllegalStateException("connection not init"); @@ -240,7 +241,7 @@ public synchronized ObDirectLoadStatement createStatement() throws ObDirectLoadE logger.warn("connection is closed"); throw new ObDirectLoadIllegalStateException("connection is closed"); } - ObDirectLoadStatement stmt = new ObDirectLoadStatement(this); + ObDirectLoadStatement stmt = new ObDirectLoadStatement(this, traceId); this.statementList.addLast(stmt); return stmt; } @@ -257,7 +258,9 @@ ObDirectLoadStatement buildStatement(ObDirectLoadStatement.Builder builder) throws ObDirectLoadException { ObDirectLoadStatement stmt = null; try { - stmt = createStatement(); + final ObDirectLoadTraceId traceId = builder.getTraceId() != null ? builder.getTraceId() + : ObDirectLoadTraceId.generateTraceId(); + stmt = createStatement(traceId); stmt.init(builder); } catch (Exception e) { logger.warn("build statement failed, args:" + builder, e); diff --git a/src/main/java/com/alipay/oceanbase/rpc/direct_load/ObDirectLoadStatement.java b/src/main/java/com/alipay/oceanbase/rpc/direct_load/ObDirectLoadStatement.java index 13686bb9..06a01666 100644 --- a/src/main/java/com/alipay/oceanbase/rpc/direct_load/ObDirectLoadStatement.java +++ b/src/main/java/com/alipay/oceanbase/rpc/direct_load/ObDirectLoadStatement.java @@ -55,9 +55,9 @@ public class ObDirectLoadStatement { private ObDirectLoadStatementExecutor executor = null; private long startQueryTimeMillis = 0; - ObDirectLoadStatement(ObDirectLoadConnection connection) { + ObDirectLoadStatement(ObDirectLoadConnection connection, ObDirectLoadTraceId traceId) { this.connection = connection; - this.traceId = ObDirectLoadTraceId.generateTraceId(); + this.traceId = traceId; this.logger = ObDirectLoadLogger.getLogger(this.traceId); } @@ -88,6 +88,9 @@ public synchronized void init(Builder builder) throws ObDirectLoadException { obTablePool = new ObDirectLoadConnection.ObTablePool(connection, logger, queryTimeout); obTablePool.init(); executor = new ObDirectLoadStatementExecutor(this); + if (builder.executionId != null) { + executor.resume(builder.executionId); + } startQueryTimeMillis = System.currentTimeMillis(); isInited = true; logger.info("statement init successful, args:" + builder); @@ -294,6 +297,7 @@ public ObDirectLoadStatementExecutionId getExecutionId() throws ObDirectLoadExce return executor.getExecutionId(); } + @Deprecated public void resume(ObDirectLoadStatementExecutionId executionId) throws ObDirectLoadException { if (executionId == null || !executionId.isValid()) { logger.warn("Param 'executionId' must not be null or invalid, value:" + executionId); @@ -306,20 +310,23 @@ public void resume(ObDirectLoadStatementExecutionId executionId) throws ObDirect public static final class Builder { - private final ObDirectLoadConnection connection; + private final ObDirectLoadConnection connection; - private String tableName = null; - private String[] columnNames = null; - private String[] partitionNames = null; - private ObLoadDupActionType dupAction = ObLoadDupActionType.INVALID_MODE; + private String tableName = null; + private String[] columnNames = null; + private String[] partitionNames = null; + private ObLoadDupActionType dupAction = ObLoadDupActionType.INVALID_MODE; - private int parallel = 0; - private long queryTimeout = 0; + private int parallel = 0; + private long queryTimeout = 0; - private long maxErrorRowCount = 0; - private String loadMethod = "full"; + private long maxErrorRowCount = 0; + private String loadMethod = "full"; - private static final long MAX_QUERY_TIMEOUT = Integer.MAX_VALUE; + private ObDirectLoadTraceId traceId = null; + private ObDirectLoadStatementExecutionId executionId = null; + + private static final long MAX_QUERY_TIMEOUT = Integer.MAX_VALUE; Builder(ObDirectLoadConnection connection) { this.connection = connection; @@ -365,12 +372,22 @@ public Builder setLoadMethod(String loadMethod) { return this; } + public Builder setExecutionId(ObDirectLoadStatementExecutionId executionId) { + this.traceId = executionId.getTraceId(); + this.executionId = executionId; + return this; + } + + public ObDirectLoadTraceId getTraceId() { + return traceId; + } + public String toString() { return String .format( - "{tableName:%s, columnNames:%s, partitionNames:%s, dupAction:%s, parallel:%d, queryTimeout:%d, maxErrorRowCount:%d, loadMethod:%s}", + "{tableName:%s, columnNames:%s, partitionNames:%s, dupAction:%s, parallel:%d, queryTimeout:%d, maxErrorRowCount:%d, loadMethod:%s, executionId:%s}", tableName, Arrays.toString(columnNames), Arrays.toString(partitionNames), - dupAction, parallel, queryTimeout, maxErrorRowCount, loadMethod); + dupAction, parallel, queryTimeout, maxErrorRowCount, loadMethod, executionId); } public ObDirectLoadStatement build() throws ObDirectLoadException { diff --git a/src/main/java/com/alipay/oceanbase/rpc/direct_load/ObDirectLoadTraceId.java b/src/main/java/com/alipay/oceanbase/rpc/direct_load/ObDirectLoadTraceId.java index f74559b3..e7976f3d 100644 --- a/src/main/java/com/alipay/oceanbase/rpc/direct_load/ObDirectLoadTraceId.java +++ b/src/main/java/com/alipay/oceanbase/rpc/direct_load/ObDirectLoadTraceId.java @@ -20,6 +20,12 @@ import java.net.InetAddress; import java.util.concurrent.atomic.AtomicLong; +import com.alipay.oceanbase.rpc.util.ObByteBuf; +import com.alipay.oceanbase.rpc.util.Serialization; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; + public class ObDirectLoadTraceId { private final long uniqueId; @@ -30,10 +36,6 @@ public ObDirectLoadTraceId(long uniqueId, long sequence) { this.sequence = sequence; } - public String toString() { - return String.format("Y%X-%016X", uniqueId, sequence); - } - public long getUniqueId() { return uniqueId; } @@ -42,6 +44,40 @@ public long getSequence() { return sequence; } + public String toString() { + return String.format("Y%X-%016X", uniqueId, sequence); + } + + public byte[] encode() { + int needBytes = (int) getEncodedSize(); + ObByteBuf buf = new ObByteBuf(needBytes); + encode(buf); + return buf.bytes; + } + + public void encode(ObByteBuf buf) { + Serialization.encodeVi64(buf, uniqueId); + Serialization.encodeVi64(buf, sequence); + } + + public static ObDirectLoadTraceId decode(ByteBuf buf) { + long uniqueId = Serialization.decodeVi64(buf); + long sequence = Serialization.decodeVi64(buf); + return new ObDirectLoadTraceId(uniqueId, sequence); + } + + public static ObDirectLoadTraceId decode(byte[] bytes) { + ByteBuf buf = Unpooled.wrappedBuffer(bytes); + return decode(buf); + } + + public int getEncodedSize() { + int len = 0; + len += Serialization.getNeedBytes(uniqueId); + len += Serialization.getNeedBytes(sequence); + return len; + } + public static final ObDirectLoadTraceId DEFAULT_TRACE_ID; public static TraceIdGenerator traceIdGenerator; diff --git a/src/main/java/com/alipay/oceanbase/rpc/direct_load/execution/ObDirectLoadStatementExecutionId.java b/src/main/java/com/alipay/oceanbase/rpc/direct_load/execution/ObDirectLoadStatementExecutionId.java index d98dcac9..5c8a6b07 100644 --- a/src/main/java/com/alipay/oceanbase/rpc/direct_load/execution/ObDirectLoadStatementExecutionId.java +++ b/src/main/java/com/alipay/oceanbase/rpc/direct_load/execution/ObDirectLoadStatementExecutionId.java @@ -17,6 +17,7 @@ package com.alipay.oceanbase.rpc.direct_load.execution; +import com.alipay.oceanbase.rpc.direct_load.ObDirectLoadTraceId; import com.alipay.oceanbase.rpc.direct_load.exception.ObDirectLoadException; import com.alipay.oceanbase.rpc.direct_load.exception.ObDirectLoadIllegalArgumentException; import com.alipay.oceanbase.rpc.protocol.payload.impl.ObAddr; @@ -28,9 +29,11 @@ public class ObDirectLoadStatementExecutionId { - private long tableId = 0; - private long taskId = 0; - private ObAddr svrAddr = new ObAddr(); + private long tableId = 0; + private long taskId = 0; + private ObAddr svrAddr = new ObAddr(); + + private ObDirectLoadTraceId traceId = null; public ObDirectLoadStatementExecutionId() { } @@ -46,6 +49,20 @@ public ObDirectLoadStatementExecutionId(long tableId, long taskId, ObAddr svrAdd this.svrAddr = svrAddr; } + public ObDirectLoadStatementExecutionId(long tableId, long taskId, ObAddr svrAddr, + ObDirectLoadTraceId traceId) + throws ObDirectLoadException { + if (tableId < 0 || taskId <= 0 || svrAddr == null || traceId == null) { + throw new ObDirectLoadIllegalArgumentException(String.format( + "invalid args, tableId:%d, taskId:%d, svrAddr:%s, traceId:%s", tableId, taskId, + svrAddr, traceId)); + } + this.tableId = tableId; + this.taskId = taskId; + this.svrAddr = svrAddr; + this.traceId = traceId; + } + public long getTableId() { return tableId; } @@ -58,12 +75,17 @@ public ObAddr getSvrAddr() { return svrAddr; } + public ObDirectLoadTraceId getTraceId() { + return traceId; + } + public boolean isValid() { return tableId >= 0 && taskId > 0 && svrAddr.isValid(); } public String toString() { - return String.format("{tableId:%d, taskId:%d, svrAddr:%s}", tableId, taskId, svrAddr); + return String.format("{tableId:%d, taskId:%d, svrAddr:%s, traceId:%s}", tableId, taskId, + svrAddr, traceId); } public byte[] encode() { @@ -77,12 +99,18 @@ public void encode(ObByteBuf buf) { Serialization.encodeVi64(buf, tableId); Serialization.encodeVi64(buf, taskId); svrAddr.encode(buf); + if (traceId != null) { + traceId.encode(buf); + } } public ObDirectLoadStatementExecutionId decode(ByteBuf buf) { tableId = Serialization.decodeVi64(buf); taskId = Serialization.decodeVi64(buf); svrAddr.decode(buf); + if (buf.readableBytes() > 0) { + traceId = ObDirectLoadTraceId.decode(buf); + } return this; } @@ -96,6 +124,9 @@ public int getEncodedSize() { len += Serialization.getNeedBytes(tableId); len += Serialization.getNeedBytes(taskId); len += svrAddr.getEncodedSize(); + if (traceId != null) { + len += traceId.getEncodedSize(); + } return len; } diff --git a/src/main/java/com/alipay/oceanbase/rpc/direct_load/execution/ObDirectLoadStatementExecutor.java b/src/main/java/com/alipay/oceanbase/rpc/direct_load/execution/ObDirectLoadStatementExecutor.java index 5792836f..f9551849 100644 --- a/src/main/java/com/alipay/oceanbase/rpc/direct_load/execution/ObDirectLoadStatementExecutor.java +++ b/src/main/java/com/alipay/oceanbase/rpc/direct_load/execution/ObDirectLoadStatementExecutor.java @@ -59,6 +59,8 @@ public class ObDirectLoadStatementExecutor { private ObAddr svrAddr = null; private ObDirectLoadException cause = null; // 失败原因 + private AtomicInteger writingCount = new AtomicInteger(0); + public ObDirectLoadStatementExecutor(ObDirectLoadStatement statement) { this.statement = statement; this.traceId = statement.getTraceId(); @@ -164,7 +166,7 @@ public synchronized void detach() throws ObDirectLoadException { public ObDirectLoadStatementExecutionId getExecutionId() throws ObDirectLoadException { checkState(LOADING, "getExecutionId"); ObDirectLoadStatementExecutionId executionId = new ObDirectLoadStatementExecutionId( - tableId, taskId, svrAddr); + tableId, taskId, svrAddr, traceId); return executionId; } @@ -247,6 +249,25 @@ public void close() { logger.warn("statement abort failed", e); } } + // 如果还有写没结束, 等待写结束 + if (writingCount.get() > 0) { + logger.info("statement close wait write"); + try { + final long startTimeMillis = System.currentTimeMillis(); + long loopCnt = 0; + while (writingCount.get() > 0) { + Thread.sleep(10); + ++loopCnt; + if (loopCnt % 100 == 0) { + final long curTimeMillis = System.currentTimeMillis(); + logger.warn("statement has been wait write for " + + (curTimeMillis - startTimeMillis) + " ms"); + } + } + } catch (Exception e) { + logger.warn("statement wait write failed", e); + } + } } private synchronized void abortIfNeed() { @@ -343,15 +364,23 @@ void stopHeartBeat() { public void write(ObDirectLoadBucket bucket) throws ObDirectLoadException { checkState(LOADING, LOADING_ONLY, "write"); - ObDirectLoadStatementPromiseTask task = new ObDirectLoadStatementWriteTask(statement, this, - bucket); - task.run(); - if (!task.isDone()) { - logger.warn("statement write task unexpected not done"); - throw new ObDirectLoadUnexpectedException("statement write task unexpected not done"); - } - if (!task.isSuccess()) { - throw task.cause(); + writingCount.incrementAndGet(); + try { + ObDirectLoadStatementPromiseTask task = new ObDirectLoadStatementWriteTask(statement, + this, bucket); + task.run(); + if (!task.isDone()) { + logger.warn("statement write task unexpected not done"); + throw new ObDirectLoadUnexpectedException( + "statement write task unexpected not done"); + } + if (!task.isSuccess()) { + throw task.cause(); + } + } catch (ObDirectLoadException e) { + throw e; + } finally { + writingCount.decrementAndGet(); } } diff --git a/src/main/java/com/alipay/oceanbase/rpc/direct_load/execution/ObDirectLoadStatementWriteTask.java b/src/main/java/com/alipay/oceanbase/rpc/direct_load/execution/ObDirectLoadStatementWriteTask.java index 03931913..d1d6bde2 100644 --- a/src/main/java/com/alipay/oceanbase/rpc/direct_load/execution/ObDirectLoadStatementWriteTask.java +++ b/src/main/java/com/alipay/oceanbase/rpc/direct_load/execution/ObDirectLoadStatementWriteTask.java @@ -25,7 +25,10 @@ import com.alipay.oceanbase.rpc.direct_load.exception.ObDirectLoadException; import com.alipay.oceanbase.rpc.direct_load.exception.ObDirectLoadExceptionUtil; import com.alipay.oceanbase.rpc.direct_load.future.ObDirectLoadStatementPromiseTask; +import com.alipay.oceanbase.rpc.direct_load.protocol.payload.ObDirectLoadGetStatusRpc; import com.alipay.oceanbase.rpc.direct_load.protocol.payload.ObDirectLoadInsertRpc; +import com.alipay.oceanbase.rpc.direct_load.protocol.payload.ObTableLoadClientStatus; +import com.alipay.oceanbase.rpc.protocol.payload.ResultCodes; import com.alipay.oceanbase.rpc.table.ObTable; import com.alipay.oceanbase.rpc.direct_load.protocol.ObDirectLoadProtocol; import com.alipay.oceanbase.rpc.util.ObByteBuf; @@ -87,6 +90,9 @@ private void sendInsert(ObTable table, ObByteBuf payloadBuffer) throws ObDirectL } catch (ObDirectLoadException e) { logger.warn("statement send insert failed, retry after " + retryInterval + "s, retryCount:" + retryCount, e); + + // 查询服务端状态决定是否重试 + checkStatus(); // 忽略所有发送失败错误码, 重试到任务状态为FAIL ++retryCount; try { @@ -103,6 +109,35 @@ private void sendInsert(ObTable table, ObByteBuf payloadBuffer) throws ObDirectL } } + private void checkStatus() throws ObDirectLoadException { + ObDirectLoadGetStatusRpc rpc = null; + try { + rpc = doGetStatus(); + } catch (ObDirectLoadException e) { + logger.warn("statement send get status rpc failed", e); + throw e; + } + + ObTableLoadClientStatus status = rpc.getStatus(); + int errorCode = rpc.getErrorCode(); + switch (status) { + case RUNNING: + break; + case ERROR: + logger.warn("statement server status is error, errorCode:" + errorCode); + throw ObDirectLoadExceptionUtil.convertException(status, errorCode); + case ABORT: + logger.warn("statement server status is abort, errorCode:" + errorCode); + if (errorCode == ResultCodes.OB_SUCCESS.errorCode) { + errorCode = ResultCodes.OB_CANCELED.errorCode; + } + throw ObDirectLoadExceptionUtil.convertException(status, errorCode); + default: + logger.warn("statement server status is unexpected, status:" + status); + throw ObDirectLoadExceptionUtil.convertException(status, errorCode); + } + } + private ObDirectLoadInsertRpc doSendInsert(ObTable table, ObByteBuf payloadBuffer, long timeoutMillis) throws ObDirectLoadException { // send insert rpc @@ -119,4 +154,20 @@ private ObDirectLoadInsertRpc doSendInsert(ObTable table, ObByteBuf payloadBuffe return rpc; } + private ObDirectLoadGetStatusRpc doGetStatus() throws ObDirectLoadException { + final ObTable table = statement.getObTablePool().getControlObTable(); + final long timeoutMillis = statement.getTimeoutRemain(); + + ObDirectLoadGetStatusRpc rpc = protocol.getGetStatusRpc(executor.getTraceId()); + rpc.setSvrAddr(executor.getSvrAddr()); + rpc.setTableId(executor.getTableId()); + rpc.setTaskId(executor.getTaskId()); + + logger.debug("statement send get status rpc, arg:" + rpc.getArg()); + connection.executeWithConnection(rpc, table, timeoutMillis); + logger.debug("statement get status rpc response successful, res:" + rpc.getRes()); + + return rpc; + } + }