diff --git a/src/main/java/com/alipay/oceanbase/rpc/table/ObTableClientBatchOpsImpl.java b/src/main/java/com/alipay/oceanbase/rpc/table/ObTableClientBatchOpsImpl.java index a0648fb0..5ae35fe9 100644 --- a/src/main/java/com/alipay/oceanbase/rpc/table/ObTableClientBatchOpsImpl.java +++ b/src/main/java/com/alipay/oceanbase/rpc/table/ObTableClientBatchOpsImpl.java @@ -405,39 +405,50 @@ public void partitionExecute(ObTableOperationResult[] results, List subObTableOperationResults = subObTableBatchOperationResult .getResults(); - - if (subObTableOperationResults.size() < subOperations.getTableOperations().size()) { - // only one result when it across failed - // only one result when hkv puts - if (subObTableOperationResults.size() == 1) { - ObTableOperationResult subObTableOperationResult = subObTableOperationResults + if (returnOneResult) { + if (results[0] == null) { + results[0] = new ObTableOperationResult(); + } + ObTableOperationResult subObTableOperationResult = subObTableOperationResults .get(0); - subObTableOperationResult.setExecuteHost(subObTable.getIp()); - subObTableOperationResult.setExecutePort(subObTable.getPort()); - for (ObPair aSubOperationWithIndexList : subOperationWithIndexList) { - results[aSubOperationWithIndexList.getLeft()] = subObTableOperationResult; + subObTableOperationResult.setExecuteHost(subObTable.getIp()); + subObTableOperationResult.setExecutePort(subObTable.getPort()); + subObTableOperationResult.setAffectedRows(results[0].getAffectedRows() + subObTableOperationResult.getAffectedRows()); + results[0] = subObTableOperationResult; + } else { + if (subObTableOperationResults.size() < subOperations.getTableOperations().size()) { + // only one result when it across failed + // only one result when hkv puts + if (subObTableOperationResults.size() == 1) { + ObTableOperationResult subObTableOperationResult = subObTableOperationResults + .get(0); + subObTableOperationResult.setExecuteHost(subObTable.getIp()); + subObTableOperationResult.setExecutePort(subObTable.getPort()); + for (ObPair aSubOperationWithIndexList : subOperationWithIndexList) { + results[aSubOperationWithIndexList.getLeft()] = subObTableOperationResult; + } + } else { + // unexpected result found + throw new IllegalArgumentException( + "check batch operation result size error: operation size [" + + subOperations.getTableOperations().size() + "] result size [" + + subObTableOperationResults.size() + "]"); } } else { - // unexpected result found - throw new IllegalArgumentException( - "check batch operation result size error: operation size [" - + subOperations.getTableOperations().size() + "] result size [" - + subObTableOperationResults.size() + "]"); - } - } else { - if (subOperationWithIndexList.size() != subObTableOperationResults.size()) { - throw new ObTableUnexpectedException("check batch result error: partition " - + partId + " expect result size " - + subOperationWithIndexList.size() - + " actual result size " - + subObTableOperationResults.size()); - } - for (int i = 0; i < subOperationWithIndexList.size(); i++) { - ObTableOperationResult subObTableOperationResult = subObTableOperationResults - .get(i); - subObTableOperationResult.setExecuteHost(subObTable.getIp()); - subObTableOperationResult.setExecutePort(subObTable.getPort()); - results[subOperationWithIndexList.get(i).getLeft()] = subObTableOperationResult; + if (subOperationWithIndexList.size() != subObTableOperationResults.size()) { + throw new ObTableUnexpectedException("check batch result error: partition " + + partId + " expect result size " + + subOperationWithIndexList.size() + + " actual result size " + + subObTableOperationResults.size()); + } + for (int i = 0; i < subOperationWithIndexList.size(); i++) { + ObTableOperationResult subObTableOperationResult = subObTableOperationResults + .get(i); + subObTableOperationResult.setExecuteHost(subObTable.getIp()); + subObTableOperationResult.setExecutePort(subObTable.getPort()); + results[subOperationWithIndexList.get(i).getLeft()] = subObTableOperationResult; + } } } String endpoint = subObTable.getIp() + ":" + subObTable.getPort(); @@ -455,10 +466,26 @@ public ObTableBatchOperationResult executeInternal() throws Exception { if (tableName == null || tableName.isEmpty()) { throw new IllegalArgumentException("table name is null"); } - long start = System.currentTimeMillis(); List operations = batchOperation.getTableOperations(); - final ObTableOperationResult[] obTableOperationResults = new ObTableOperationResult[operations - .size()]; + if (operations.isEmpty()) { + throw new IllegalArgumentException("operations is empty"); + } + ObTableOperationType lastType = operations.get(0).getOperationType(); + if (returnOneResult + && !(batchOperation.isSameType() && (lastType == ObTableOperationType.INSERT + || lastType == ObTableOperationType.PUT + || lastType == ObTableOperationType.REPLACE || lastType == ObTableOperationType.DEL))) { + throw new IllegalArgumentException( + "returnOneResult only support multi-insert/put/replace/del"); + } + long start = System.currentTimeMillis(); + ObTableOperationResult[] obTableOperationResults = null; + if (returnOneResult) { + obTableOperationResults = new ObTableOperationResult[1]; + } else { + obTableOperationResults = new ObTableOperationResult[operations.size()]; + } + Map>>> partitions = partitionPrepare(); long getTableTime = System.currentTimeMillis(); final Map context = ThreadLocalMap.getContextMap(); @@ -466,7 +493,8 @@ public ObTableBatchOperationResult executeInternal() throws Exception { final ConcurrentTaskExecutor executor = new ConcurrentTaskExecutor(executorService, partitions.size()); for (final Map.Entry>>> entry : partitions - .entrySet()) { + .entrySet()) { + ObTableOperationResult[] finalObTableOperationResults = obTableOperationResults; executor.execute(new ConcurrentTask() { /* * Do task. @@ -475,7 +503,7 @@ public ObTableBatchOperationResult executeInternal() throws Exception { public void doTask() { try { ThreadLocalMap.transmitContextMap(context); - partitionExecute(obTableOperationResults, entry); + partitionExecute(finalObTableOperationResults, entry); } catch (Exception e) { logger.error(LCD.convert("01-00026"), e); executor.collectExceptions(e); @@ -541,7 +569,6 @@ public void doTask() { return batchOperationResult; } - /* * clear batch operations1 */ diff --git a/src/test/java/com/alipay/oceanbase/rpc/ObAtomicBatchOperationTest.java b/src/test/java/com/alipay/oceanbase/rpc/ObAtomicBatchOperationTest.java index 8c709a7c..5f9a4b38 100644 --- a/src/test/java/com/alipay/oceanbase/rpc/ObAtomicBatchOperationTest.java +++ b/src/test/java/com/alipay/oceanbase/rpc/ObAtomicBatchOperationTest.java @@ -19,6 +19,7 @@ import com.alipay.oceanbase.rpc.exception.ObTableDuplicateKeyException; import com.alipay.oceanbase.rpc.exception.ObTableException; +import com.alipay.oceanbase.rpc.mutation.Insert; import com.alipay.oceanbase.rpc.table.api.TableBatchOps; import com.alipay.oceanbase.rpc.util.ObTableClientTestUtil; import org.junit.After; @@ -29,6 +30,9 @@ import java.util.List; import java.util.Map; +import static com.alipay.oceanbase.rpc.mutation.MutationFactory.colVal; +import static com.alipay.oceanbase.rpc.mutation.MutationFactory.row; + public class ObAtomicBatchOperationTest { private static final int dataSetSize = 4; private static final String successKey = "abc-5"; @@ -52,7 +56,7 @@ public void setup() throws Exception { String key = "abc-" + i; String val = "xyz-" + i; this.obTableClient.insert("test_varchar_table", key, new String[] { "c2" }, - new String[] { val }); + new String[] { val }); } } @@ -213,10 +217,8 @@ public void testReturnOneRes() { batchOps.insert("abcd-8", new String[] { "c2" }, new String[] { "returnOne-8" }); batchOps.insert("abcd-9", new String[] { "c2" }, new String[] { "returnOne-9" }); List results = batchOps.execute(); - Assert.assertEquals(results.size(), 3); + Assert.assertEquals(results.size(), 1); Assert.assertEquals(results.get(0), 3L); - Assert.assertEquals(results.get(1), 3L); - Assert.assertEquals(results.get(2), 3L); } catch (Exception ex) { Assert.assertTrue(ex instanceof ObTableException); } @@ -230,10 +232,8 @@ public void testReturnOneRes() { batchOps.insert("abcd-5", new String[] { "c2" }, new String[] { "returnOne-5" }); batchOps.insert("abcd-6", new String[] { "c2" }, new String[] { "returnOne-6" }); List results = batchOps.execute(); - Assert.assertEquals(results.size(), 3); + Assert.assertEquals(results.size(), 1); Assert.assertEquals(results.get(0), 3L); - Assert.assertEquals(results.get(1), 3L); - Assert.assertEquals(results.get(2), 3L); } catch (Exception ex) { Assert.assertTrue(ex instanceof ObTableException); }