diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java
index 48414ec576..a2928acf20 100644
--- a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java
+++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java
@@ -177,6 +177,29 @@ public class RssSparkConfig {
           .defaultValue(1.0d)
           .withDescription(
               "The buffer size to spill when spill triggered by config spark.rss.writer.buffer.spill.size");
+
+  public static final ConfigOption<Double> RSS_WRITER_MAX_ALLOCATED_MEMORY_RATIO =
+      ConfigOptions.key("rss.writer.maxAllocatedMemoryRatio")
+          .doubleType()
+          .checkValue(
+              ConfigUtils.DOUBLE_VALIDATOR_ZERO_TO_ONE,
+              "The 'rss.writer.maxAllocatedMemoryRatio' must be between 0.0 and 1.0")
+          .defaultValue(0.0d)
+          .withDescription(
+              "Max fraction of spark.executor.memory that the shuffle write buffer may keep allocated; "
+                  + "once reached, the writer waits for memory to be released instead of growing further. 0 or a negative value disables the cap.");
+
+  public static final ConfigOption<Long> RSS_WRITER_MAX_ALLOCATED_WAIT_TIMEOUT_MS =
+      ConfigOptions.key("rss.writer.maxAllocatedWaitTimeoutMillis")
+          .longType()
+          .checkValue(
+              ConfigUtils.POSITIVE_LONG_VALIDATOR,
+              "The 'rss.writer.maxAllocatedWaitTimeoutMillis' must be positive")
+          .defaultValue(10 * 60 * 1000L)
+          .withDescription(
+              "Timeout in milliseconds when waiting for allocated shuffle buffer memory to drop below "
+                  + "the ratio cap (e.g. when remote pushes are slow to complete). Defaults to 10 minutes.");
+
   public static final ConfigOption<Integer> RSS_PARTITION_REASSIGN_MAX_REASSIGNMENT_SERVER_NUM =
       ConfigOptions.key("rss.client.reassign.maxReassignServerNum")
           .intType()
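
Note: a minimal sketch of how a job could opt in to the new cap, setting the options on SparkConf the same way the test below does; the 4g executor size, 0.2 ratio, and 5-minute timeout are illustrative values, not defaults from this patch.

    SparkConf conf = new SparkConf();
    conf.set("spark.executor.memory", "4g");
    // Cap shuffle write buffers at 20% of executor memory (illustrative ratio).
    conf.set(RssSparkConfig.RSS_WRITER_MAX_ALLOCATED_MEMORY_RATIO.key(), "0.2");
    // Fail the task if memory does not drop below the cap within 5 minutes (illustrative).
    conf.set(
        RssSparkConfig.RSS_WRITER_MAX_ALLOCATED_WAIT_TIMEOUT_MS.key(),
        String.valueOf(5 * 60 * 1000L));
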
diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/BufferManagerOptions.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/BufferManagerOptions.java
index 839e684a93..084fad5730 100644
--- a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/BufferManagerOptions.java
+++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/BufferManagerOptions.java
@@ -36,6 +36,10 @@ public class BufferManagerOptions {
   private long requireMemoryInterval;
   private int requireMemoryRetryMax;
   private double bufferSpillPercent;
+  /** Long.MAX_VALUE means no cap (disabled). */
+  private long maxAllocatedBytesLimit;
+
+  private long maxAllocatedWaitTimeoutMs;

   public BufferManagerOptions(SparkConf sparkConf) {
     bufferSize =
@@ -64,6 +68,26 @@ public BufferManagerOptions(SparkConf sparkConf) {
             RssSparkConfig.RSS_WRITER_PRE_ALLOCATED_BUFFER_SIZE.defaultValue().get());
     requireMemoryInterval = sparkConf.get(RssSparkConfig.RSS_WRITER_REQUIRE_MEMORY_INTERVAL);
     requireMemoryRetryMax = sparkConf.get(RssSparkConfig.RSS_WRITER_REQUIRE_MEMORY_RETRY_MAX);
+
+    double maxAllocatedRatio =
+        sparkConf.getDouble(
+            RssSparkConfig.RSS_WRITER_MAX_ALLOCATED_MEMORY_RATIO.key(),
+            RssSparkConfig.RSS_WRITER_MAX_ALLOCATED_MEMORY_RATIO.defaultValue());
+
+    String executorMemoryKey = "spark.executor.memory";
+    if (maxAllocatedRatio > 0d && sparkConf.contains(executorMemoryKey)) {
+      double ratio = Math.min(maxAllocatedRatio, 1.0d);
+      this.maxAllocatedBytesLimit = (long) (sparkConf.getSizeAsBytes(executorMemoryKey) * ratio);
+      if (this.maxAllocatedBytesLimit <= 0) {
+        this.maxAllocatedBytesLimit = Long.MAX_VALUE;
+      }
+    } else {
+      this.maxAllocatedBytesLimit = Long.MAX_VALUE;
+    }
+    this.maxAllocatedWaitTimeoutMs =
+        sparkConf.getLong(
+            RssSparkConfig.RSS_WRITER_MAX_ALLOCATED_WAIT_TIMEOUT_MS.key(),
+            RssSparkConfig.RSS_WRITER_MAX_ALLOCATED_WAIT_TIMEOUT_MS.defaultValue());
     if (LOG.isDebugEnabled()) {
       LOG.debug(
           "New buffer manager options, bufferSize: {}, bufferSpillThreshold: {}, preAllocatedBufferSize: {}",
@@ -138,4 +162,15 @@ public long getRequireMemoryInterval() {
   public int getRequireMemoryRetryMax() {
     return requireMemoryRetryMax;
   }
+
+  /**
+   * @return max {@code allocatedBytes} for WriteBufferManager; {@link Long#MAX_VALUE} if disabled
+   */
+  public long getMaxAllocatedBytesLimit() {
+    return maxAllocatedBytesLimit;
+  }
+
+  public long getMaxAllocatedWaitTimeoutMs() {
+    return maxAllocatedWaitTimeoutMs;
+  }
 }
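
Note: a worked example of the cap arithmetic in the constructor above, assuming spark.executor.memory is 4g and the ratio is 0.2 (both illustrative):

    SparkConf conf = new SparkConf().set("spark.executor.memory", "4g");
    conf.set(RssSparkConfig.RSS_WRITER_MAX_ALLOCATED_MEMORY_RATIO.key(), "0.2");
    BufferManagerOptions options = new BufferManagerOptions(conf);
    // getSizeAsBytes("spark.executor.memory") = 4 * 1024^3 = 4294967296 bytes,
    // so the cap is (long) (4294967296 * 0.2) = 858993459 bytes (~819 MiB).
    long cap = options.getMaxAllocatedBytesLimit();
    // With the default ratio of 0.0, the limit stays Long.MAX_VALUE, i.e. the cap is disabled.
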
diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java
index 280c6aefd8..bf27ec3221 100644
--- a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java
+++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java
@@ -59,6 +59,7 @@
 import org.apache.uniffle.common.exception.RssException;
 import org.apache.uniffle.common.util.BlockIdLayout;
 import org.apache.uniffle.common.util.ChecksumUtils;
+import org.apache.uniffle.common.util.ThreadUtils;

 public class WriteBufferManager extends MemoryConsumer {
@@ -113,6 +114,10 @@ public class WriteBufferManager extends MemoryConsumer {
   private ShuffleServerPushCostTracker shuffleServerPushCostTracker;
   // whether to use deferred compression for shuffle blocks
   private final boolean isDeferredCompression;
+  /** Long.MAX_VALUE disables the cap. */
+  private final long maxAllocatedBytesLimit;
+
+  private final long maxAllocatedWaitTimeoutMs;

 public WriteBufferManager(
     int shuffleId,
@@ -210,6 +215,8 @@ public WriteBufferManager(
     this.partitionAssignmentRetrieveFunc = partitionAssignmentRetrieveFunc;
     this.stageAttemptNumber = stageAttemptNumber;
     this.shuffleServerPushCostTracker = new ShuffleServerPushCostTracker();
+    this.maxAllocatedBytesLimit = bufferManagerOptions.getMaxAllocatedBytesLimit();
+    this.maxAllocatedWaitTimeoutMs = bufferManagerOptions.getMaxAllocatedWaitTimeoutMs();
   }

   public WriteBufferManager(
@@ -530,10 +537,34 @@ private void requestMemory(long requiredMem) {
   }

   private void requestExecutorMemory(long leastMem) {
-    long gotMem = acquireMemory(askExecutorMemory);
-    allocatedBytes.addAndGet(gotMem);
     int retry = 0;
-    while (allocatedBytes.get() - usedBytes.get() < leastMem) {
+    long gotMem = 0;
+    long maxAllocatedTimeoutStarted = System.currentTimeMillis();
+    while (retry <= requireMemoryRetryMax) {
+      // cap total allocated bytes to avoid OOM, and bound the wait so the task cannot block forever
+      while (allocatedBytes.get() >= maxAllocatedBytesLimit) {
+        // trigger push to the remote shuffle servers to release buffered memory
+        spillFunc.apply(clear(bufferSpillRatio));
+        LOG.info(
+            "Allocated memory[{}] has reached the limit[{}], sleeping and waiting for memory to be released.",
+            allocatedBytes.get(),
+            maxAllocatedBytesLimit);
+        ThreadUtils.sleep(
+            requireMemoryInterval, "Interrupted when waiting for allocated memory to be released.");
+        if (System.currentTimeMillis() - maxAllocatedTimeoutStarted >= maxAllocatedWaitTimeoutMs) {
+          String errorMsg =
+              String.format(
+                  "Waiting timeout: allocated memory stayed at the limit, allocatedBytes: %d, maxAllocatedBytesLimit: %d",
+                  allocatedBytes.get(), maxAllocatedBytesLimit);
+          throw new RssException(errorMsg);
+        }
+      }
+
+      gotMem = acquireMemory(askExecutorMemory);
+      allocatedBytes.addAndGet(gotMem);
+      if (allocatedBytes.get() - usedBytes.get() >= leastMem) {
+        return;
+      }
       LOG.info(
           "Can't get memory for now, sleep and try["
               + retry
@@ -543,33 +574,27 @@ private void requestExecutorMemory(long leastMem) {
               + gotMem
               + "] less than "
               + leastMem);
-      try {
-        Thread.sleep(requireMemoryInterval);
-      } catch (InterruptedException ie) {
-        throw new RssException("Interrupted when waiting for memory.", ie);
-      }
-      gotMem = acquireMemory(askExecutorMemory);
-      allocatedBytes.addAndGet(gotMem);
-      retry++;
-      if (retry > requireMemoryRetryMax) {
-        taskMemoryManager.showMemoryUsage();
-        String message =
-            "Can't get memory to cache shuffle data, request["
-                + askExecutorMemory
-                + "], got["
-                + gotMem
-                + "],"
-                + " WriteBufferManager allocated["
-                + allocatedBytes
-                + "] task used["
-                + used
-                + "]. It may be caused by shuffle server is full of data"
-                + " or consider to optimize 'spark.executor.memory',"
-                + " 'spark.rss.writer.buffer.spill.size'.";
-        LOG.error(message);
-        throw new RssException(message);
-      }
+      retry += 1;
+      ThreadUtils.sleep(requireMemoryInterval, "Interrupted when waiting for memory.");
     }
+
+    // retries exhausted and still not enough memory: log and fail the task
+    taskMemoryManager.showMemoryUsage();
+    String message =
+        "Can't get memory to cache shuffle data, request["
+            + askExecutorMemory
+            + "], got["
+            + gotMem
+            + "],"
+            + " WriteBufferManager allocated["
+            + allocatedBytes
+            + "] task used["
+            + used
+            + "]. This may be caused by the shuffle server being full of data;"
+            + " consider tuning 'spark.executor.memory' or"
+            + " 'spark.rss.writer.buffer.spill.size'.";
+    LOG.error(message);
+    throw new RssException(message);
   }

   public void releaseBlockResource(ShuffleBlockInfo block) {
@@ -809,4 +834,16 @@ public void close() {
   public long getUncompressedDataLen() {
     return uncompressedDataLen;
   }
+
+  public Optional<Codec> getCodec() {
+    return codec;
+  }
+
+  public void setAllocatedBytes(long allocatedBytes) {
+    this.allocatedBytes.set(allocatedBytes);
+  }
+
+  public void setUsedBytes(long usedBytes) {
+    this.usedBytes.set(usedBytes);
+  }
 }
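
Note: the control flow of the reworked requestExecutorMemory above, condensed into a simplified sketch (not the literal method body; messages abbreviated):

    long start = System.currentTimeMillis();
    for (int retry = 0; retry <= requireMemoryRetryMax; retry++) {
      // Phase 1: over the cap -> spill, sleep, and eventually time out.
      while (allocatedBytes.get() >= maxAllocatedBytesLimit) {
        spillFunc.apply(clear(bufferSpillRatio));
        ThreadUtils.sleep(requireMemoryInterval, "Interrupted while waiting for releases.");
        if (System.currentTimeMillis() - start >= maxAllocatedWaitTimeoutMs) {
          throw new RssException("Waiting timeout ...");
        }
      }
      // Phase 2: under the cap -> ask the executor for more and check headroom.
      allocatedBytes.addAndGet(acquireMemory(askExecutorMemory));
      if (allocatedBytes.get() - usedBytes.get() >= leastMem) {
        return; // enough free buffer memory for this request
      }
      ThreadUtils.sleep(requireMemoryInterval, "Interrupted when waiting for memory.");
    }
    throw new RssException("Can't get memory to cache shuffle data ...");
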
diff --git a/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/WriteBufferManagerTest.java b/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/WriteBufferManagerTest.java
index 7ebce3c54f..937cbe9e41 100644
--- a/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/WriteBufferManagerTest.java
+++ b/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/WriteBufferManagerTest.java
@@ -20,6 +20,7 @@
 import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.Collections;
 import java.util.List;
 import java.util.Optional;
 import java.util.concurrent.CompletableFuture;
@@ -28,7 +29,6 @@
 import java.util.stream.Stream;

 import com.google.common.collect.Maps;
-import org.apache.commons.lang3.reflect.FieldUtils;
 import org.apache.spark.SparkConf;
 import org.apache.spark.executor.ShuffleWriteMetrics;
 import org.apache.spark.memory.MemoryConsumer;
@@ -50,6 +50,7 @@
 import org.apache.uniffle.common.compression.Codec;
 import org.apache.uniffle.common.config.RssClientConf;
 import org.apache.uniffle.common.config.RssConf;
+import org.apache.uniffle.common.exception.RssException;
 import org.apache.uniffle.common.util.BlockIdLayout;

 import static org.junit.jupiter.api.Assertions.assertEquals;
@@ -128,7 +129,7 @@ private void addRecord(boolean compress, BlockIdLayout layout) throws IllegalAcc
       conf.set(RssSparkConfig.SPARK_SHUFFLE_COMPRESS_KEY, String.valueOf(false));
     }
     WriteBufferManager wbm = createManager(conf);
-    Optional<Codec> codec = (Optional<Codec>) FieldUtils.readField(wbm, "codec", true);
+    Optional<Codec> codec = wbm.getCodec();
     if (compress) {
       Assertions.assertTrue(codec.isPresent());
     } else {
@@ -620,6 +621,58 @@ public void spillByOwnWithSparkTaskMemoryManagerTest() {
     assertEquals(2, fakedTaskMemoryManager.getSpilledCnt());
   }

+  /**
+   * When {@code allocatedBytes} is already at the configured cap and spill does not release memory
+   * (e.g. remote push never completes), {@link WriteBufferManager} should time out waiting instead
+   * of blocking forever.
+   */
+  @Test
+  public void requestMemoryThrowsWhenAllocatedOverCapAndNotReleased() throws Exception {
+    SparkConf conf = getConf();
+    conf.set("spark.executor.memory", "20k");
+    conf.set(RssSparkConfig.RSS_WRITER_MAX_ALLOCATED_MEMORY_RATIO.key(), "0.5");
+    conf.set(RssSparkConfig.RSS_WRITER_MAX_ALLOCATED_WAIT_TIMEOUT_MS.key(), "1");
+    conf.set(RssSparkConfig.RSS_WRITER_REQUIRE_MEMORY_INTERVAL.key(), "1");
+
+    TaskMemoryManager mockTaskMemoryManager = mock(TaskMemoryManager.class);
+    BufferManagerOptions bufferOptions = new BufferManagerOptions(conf);
+    long cap = bufferOptions.getMaxAllocatedBytesLimit();
+    Assertions.assertTrue(
+        cap > 0 && cap < Long.MAX_VALUE, "max allocated cap must be enabled and finite");
+
+    Function<List<ShuffleBlockInfo>, List<CompletableFuture<Long>>> spillWithoutReleasingMemory =
+        blocks -> Collections.singletonList(CompletableFuture.completedFuture(0L));
+
+    WriteBufferManager wbm =
+        new WriteBufferManager(
+            0,
+            "taskId_maxAllocatedCapTimeoutTest",
+            0,
+            bufferOptions,
+            new KryoSerializer(conf),
+            Maps.newHashMap(),
+            mockTaskMemoryManager,
+            new ShuffleWriteMetrics(),
+            RssSparkConfig.toRssConf(conf),
+            spillWithoutReleasingMemory,
+            0);
+
+    wbm.setAllocatedBytes(cap);
+    wbm.setUsedBytes(cap - 20L);
+
+    RssException thrown =
+        Assertions.assertThrows(
+            RssException.class,
+            () -> wbm.addRecord(0, "Key", "Value"),
+            "expected timeout while waiting for allocated memory to drop below cap");
+    Assertions.assertTrue(
+        thrown.getMessage().contains("Waiting timeout"),
+        "message should describe wait timeout: " + thrown.getMessage());
+    Assertions.assertTrue(
+        thrown.getMessage().contains("maxAllocatedBytesLimit"),
+        "message should include limit context: " + thrown.getMessage());
+  }
+
   @Test
   public void addFirstRecordWithLargeSizeTest() {
     SparkConf conf = getConf();
diff --git a/common/src/main/java/org/apache/uniffle/common/util/ThreadUtils.java b/common/src/main/java/org/apache/uniffle/common/util/ThreadUtils.java
index 9f3c390987..44b130086f 100644
--- a/common/src/main/java/org/apache/uniffle/common/util/ThreadUtils.java
+++ b/common/src/main/java/org/apache/uniffle/common/util/ThreadUtils.java
@@ -40,6 +40,8 @@
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;

+import org.apache.uniffle.common.exception.RssException;
+
 public class ThreadUtils {
   private static final Logger LOGGER = LoggerFactory.getLogger(ThreadUtils.class);
   private static final ThreadMXBean THREAD_BEAN = ManagementFactory.getThreadMXBean();
@@ -224,4 +226,14 @@ public Thread newThread(final Runnable runnable) {
       return t;
     }
   }
+
+  public static void sleep(long millis, String errorMessage) {
+    try {
+      Thread.sleep(millis);
+    } catch (InterruptedException e) {
+      // restore the interrupt flag so callers up the stack can still observe it
+      Thread.currentThread().interrupt();
+      throw new RssException(errorMessage, e);
+    }
+  }
 }
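
Note: ThreadUtils.sleep rethrows InterruptedException as RssException after restoring the thread's interrupt flag, so the polling loops above fail fast instead of spinning with the interrupt swallowed. A minimal call-site sketch (the interval and message are illustrative):

    // Back off briefly between memory checks; propagate interruption as a task failure.
    ThreadUtils.sleep(100L, "Interrupted while waiting for buffer memory to be released.");
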