RssSparkConfig.java
@@ -177,6 +177,29 @@ public class RssSparkConfig {
.defaultValue(1.0d)
.withDescription(
"The buffer size to spill when spill triggered by config spark.rss.writer.buffer.spill.size");

public static final ConfigOption<Double> RSS_WRITER_MAX_ALLOCATED_MEMORY_RATIO =
ConfigOptions.key("rss.writer.maxAllocatedMemoryRatio")
.doubleType()
.checkValue(
ConfigUtils.DOUBLE_VALIDATOR_ZERO_TO_ONE,
"The 'rss.writer.maxAllocatedMemoryRatio' must be between 0.0 and 1.0")
.defaultValue(0.0d)
.withDescription(
"Max fraction of spark.executor.memory for shuffle write buffer allocated bytes; "
+ "waits instead of growing past this. 0 or negative disables.");

public static final ConfigOption<Long> RSS_WRITER_MAX_ALLOCATED_WAIT_TIMEOUT_MS =
ConfigOptions.key("rss.writer.maxAllocatedWaitTimeoutMillis")
.longType()
.checkValue(
ConfigUtils.POSITIVE_LONG_VALIDATOR,
"The 'rss.writer.maxAllocatedWaitTimeoutMillis' must be positive")
.defaultValue(10 * 60 * 1000L)
.withDescription(
"Timeout when waiting for allocated shuffle buffer memory to drop below the "
+ "ratio cap (e.g. slow remote push). Unit is milliseconds; default 10 minutes.");

public static final ConfigOption<Integer> RSS_PARTITION_REASSIGN_MAX_REASSIGNMENT_SERVER_NUM =
ConfigOptions.key("rss.client.reassign.maxReassignServerNum")
.intType()
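For illustration, a minimal sketch of how a job could enable the new cap (this is not part of the patch). The keys below are the ConfigOption keys exactly as BufferManagerOptions reads them via key(); whether a spark.rss.-prefixed form is also honored depends on the client's config wiring, so treat the literal keys as an assumption.

// Hypothetical usage sketch, not part of this patch.
SparkConf conf = new SparkConf();
conf.set("spark.executor.memory", "4g");
// Cap the writer's allocated shuffle buffer at 40% of executor memory.
conf.set(RssSparkConfig.RSS_WRITER_MAX_ALLOCATED_MEMORY_RATIO.key(), "0.4");
// Give spills five minutes to bring allocation back under the cap before failing.
conf.set(
    RssSparkConfig.RSS_WRITER_MAX_ALLOCATED_WAIT_TIMEOUT_MS.key(),
    String.valueOf(5 * 60 * 1000L));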
BufferManagerOptions.java
@@ -36,6 +36,10 @@ public class BufferManagerOptions {
private long requireMemoryInterval;
private int requireMemoryRetryMax;
private double bufferSpillPercent;
/** Long.MAX_VALUE means no cap (disabled). */
private long maxAllocatedBytesLimit;

private long maxAllocatedWaitTimeoutMs;

public BufferManagerOptions(SparkConf sparkConf) {
bufferSize =
@@ -64,6 +68,26 @@ public BufferManagerOptions(SparkConf sparkConf) {
RssSparkConfig.RSS_WRITER_PRE_ALLOCATED_BUFFER_SIZE.defaultValue().get());
requireMemoryInterval = sparkConf.get(RssSparkConfig.RSS_WRITER_REQUIRE_MEMORY_INTERVAL);
requireMemoryRetryMax = sparkConf.get(RssSparkConfig.RSS_WRITER_REQUIRE_MEMORY_RETRY_MAX);

double maxAllocatedRatio =
sparkConf.getDouble(
RssSparkConfig.RSS_WRITER_MAX_ALLOCATED_MEMORY_RATIO.key(),
RssSparkConfig.RSS_WRITER_MAX_ALLOCATED_MEMORY_RATIO.defaultValue());

String executorMemoryKey = "spark.executor.memory";
if (maxAllocatedRatio > 0d && sparkConf.contains(executorMemoryKey)) {
double ratio = Math.min(maxAllocatedRatio, 1.0d);
this.maxAllocatedBytesLimit = (long) (sparkConf.getSizeAsBytes(executorMemoryKey) * ratio);
if (this.maxAllocatedBytesLimit <= 0) {
this.maxAllocatedBytesLimit = Long.MAX_VALUE;
}
} else {
this.maxAllocatedBytesLimit = Long.MAX_VALUE;
}
this.maxAllocatedWaitTimeoutMs =
sparkConf.getLong(
RssSparkConfig.RSS_WRITER_MAX_ALLOCATED_WAIT_TIMEOUT_MS.key(),
RssSparkConfig.RSS_WRITER_MAX_ALLOCATED_WAIT_TIMEOUT_MS.defaultValue());
if (LOG.isDebugEnabled()) {
LOG.debug(
"New buffer manager options, bufferSize: {}, bufferSpillThreshold: {}, preAllocatedBufferSize: {}",
@@ -138,4 +162,15 @@ public long getRequireMemoryInterval() {
public int getRequireMemoryRetryMax() {
return requireMemoryRetryMax;
}

/**
* @return max {@code allocatedBytes} for WriteBufferManager; {@link Long#MAX_VALUE} if disabled
*/
public long getMaxAllocatedBytesLimit() {
return maxAllocatedBytesLimit;
}

public long getMaxAllocatedWaitTimeoutMs() {
return maxAllocatedWaitTimeoutMs;
}
}
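As a quick sanity check on the constructor math above, a standalone sketch with illustrative values (spark.executor.memory=4g, ratio 0.5):

// Illustrative only: mirrors the cap computation in the constructor above.
long executorMemoryBytes = 4L * 1024 * 1024 * 1024; // spark.executor.memory=4g
double ratio = Math.min(0.5d, 1.0d);                // configured ratio, clamped to 1.0
long maxAllocatedBytesLimit = (long) (executorMemoryBytes * ratio); // 2 GiB
// A non-positive result, or a disabled ratio, falls back to Long.MAX_VALUE ("no cap").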
WriteBufferManager.java
@@ -59,6 +59,7 @@
import org.apache.uniffle.common.exception.RssException;
import org.apache.uniffle.common.util.BlockIdLayout;
import org.apache.uniffle.common.util.ChecksumUtils;
import org.apache.uniffle.common.util.ThreadUtils;

public class WriteBufferManager extends MemoryConsumer {

@@ -113,6 +114,10 @@ public class WriteBufferManager {
private ShuffleServerPushCostTracker shuffleServerPushCostTracker;
// whether to use deferred compression for shuffle blocks
private final boolean isDeferredCompression;
/** Long.MAX_VALUE disables the cap. */
private final long maxAllocatedBytesLimit;

private final long maxAllocatedWaitTimeoutMs;

public WriteBufferManager(
int shuffleId,
@@ -210,6 +215,8 @@ public WriteBufferManager(
this.partitionAssignmentRetrieveFunc = partitionAssignmentRetrieveFunc;
this.stageAttemptNumber = stageAttemptNumber;
this.shuffleServerPushCostTracker = new ShuffleServerPushCostTracker();
this.maxAllocatedBytesLimit = bufferManagerOptions.getMaxAllocatedBytesLimit();
this.maxAllocatedWaitTimeoutMs = bufferManagerOptions.getMaxAllocatedWaitTimeoutMs();
}

public WriteBufferManager(
@@ -530,10 +537,34 @@ private void requestMemory(long requiredMem) {
}

private void requestExecutorMemory(long leastMem) {
long gotMem = acquireMemory(askExecutorMemory);
allocatedBytes.addAndGet(gotMem);
int retry = 0;
while (allocatedBytes.get() - usedBytes.get() < leastMem) {
long gotMem = 0;
long maxAllocatedTimeoutStarted = System.currentTimeMillis();
while (retry <= requireMemoryRetryMax) {
// limit the max bytes requested to avoid OOM, and the max wait time to avoid waiting too long
while (allocatedBytes.get() >= maxAllocatedBytesLimit) {
// trigger push to remote shuffle server
spillFunc.apply(clear(bufferSpillRatio));
LOG.info(
"Allocated memory[{}] has reached the limit[{}], sleep and wait for memory to be released.",
allocatedBytes.get(),
maxAllocatedBytesLimit);
ThreadUtils.sleep(
requireMemoryInterval, "Interrupted when waiting for allocated memory to be released.");
if (System.currentTimeMillis() - maxAllocatedTimeoutStarted >= maxAllocatedWaitTimeoutMs) {
String errorMsg =
String.format(
"Waiting timeout due to the allocated memory has reached the limit, allocatedBytes: %d, maxAllocatedBytesLimit: %d",
allocatedBytes.get(), maxAllocatedBytesLimit);
throw new RssException(errorMsg);
}
}

gotMem = acquireMemory(askExecutorMemory);
allocatedBytes.addAndGet(gotMem);
if (allocatedBytes.get() - usedBytes.get() >= leastMem) {
return;
}
LOG.info(
"Can't get memory for now, sleep and try["
+ retry
@@ -543,33 +574,27 @@ private void requestExecutorMemory(long leastMem) {
+ gotMem
+ "] less than "
+ leastMem);
try {
Thread.sleep(requireMemoryInterval);
} catch (InterruptedException ie) {
throw new RssException("Interrupted when waiting for memory.", ie);
}
gotMem = acquireMemory(askExecutorMemory);
allocatedBytes.addAndGet(gotMem);
retry++;
if (retry > requireMemoryRetryMax) {
taskMemoryManager.showMemoryUsage();
String message =
"Can't get memory to cache shuffle data, request["
+ askExecutorMemory
+ "], got["
+ gotMem
+ "],"
+ " WriteBufferManager allocated["
+ allocatedBytes
+ "] task used["
+ used
+ "]. It may be caused by shuffle server is full of data"
+ " or consider to optimize 'spark.executor.memory',"
+ " 'spark.rss.writer.buffer.spill.size'.";
LOG.error(message);
throw new RssException(message);
}
retry += 1;
ThreadUtils.sleep(requireMemoryInterval, "Interrupted when waiting for memory.");
}

// retry exceeded, still can't get enough memory, log error and throw exception
taskMemoryManager.showMemoryUsage();
String message =
"Can't get memory to cache shuffle data, request["
+ askExecutorMemory
+ "], got["
+ gotMem
+ "],"
+ " WriteBufferManager allocated["
+ allocatedBytes
+ "] task used["
+ used
+ "]. It may be caused by shuffle server is full of data"
+ " or consider to optimize 'spark.executor.memory',"
+ " 'spark.rss.writer.buffer.spill.size'.";
LOG.error(message);
throw new RssException(message);
}

public void releaseBlockResource(ShuffleBlockInfo block) {
@@ -809,4 +834,16 @@ public void close() {
public long getUncompressedDataLen() {
return uncompressedDataLen;
}

public Optional<Codec> getCodec() {
return codec;
}

public void setAllocatedBytes(long allocatedBytes) {
this.allocatedBytes.set(allocatedBytes);
}

public void setUsedBytes(long usedBytes) {
this.usedBytes.set(usedBytes);
}
}
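The new control flow in requestExecutorMemory reduces to a bounded wait: spill, sleep, re-check, and give up once the timeout elapses. A self-contained sketch of that pattern follows; all names here are illustrative, not the patch's API.

import java.util.concurrent.atomic.AtomicLong;

class BoundedWaitSketch {
  // Spill and sleep until allocation drops below the cap, or fail after the timeout.
  static void awaitBelowCap(
      AtomicLong allocated, long cap, long sleepMillis, long timeoutMillis, Runnable spill) {
    long started = System.currentTimeMillis();
    while (allocated.get() >= cap) {
      spill.run(); // ask the writer to push buffered blocks so memory can be released
      try {
        Thread.sleep(sleepMillis);
      } catch (InterruptedException e) {
        Thread.currentThread().interrupt();
        throw new RuntimeException("Interrupted while waiting for memory release", e);
      }
      if (System.currentTimeMillis() - started >= timeoutMillis) {
        throw new RuntimeException(
            "Timed out: allocated " + allocated.get() + " still >= cap " + cap);
      }
    }
  }
}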
WriteBufferManagerTest.java
@@ -20,6 +20,7 @@
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
@@ -28,7 +29,6 @@
import java.util.stream.Stream;

import com.google.common.collect.Maps;
import org.apache.commons.lang3.reflect.FieldUtils;
import org.apache.spark.SparkConf;
import org.apache.spark.executor.ShuffleWriteMetrics;
import org.apache.spark.memory.MemoryConsumer;
@@ -50,6 +50,7 @@
import org.apache.uniffle.common.compression.Codec;
import org.apache.uniffle.common.config.RssClientConf;
import org.apache.uniffle.common.config.RssConf;
import org.apache.uniffle.common.exception.RssException;
import org.apache.uniffle.common.util.BlockIdLayout;

import static org.junit.jupiter.api.Assertions.assertEquals;
@@ -128,7 +129,7 @@ private void addRecord(boolean compress, BlockIdLayout layout) throws IllegalAcc
conf.set(RssSparkConfig.SPARK_SHUFFLE_COMPRESS_KEY, String.valueOf(false));
}
WriteBufferManager wbm = createManager(conf);
Optional<Codec> codec = (Optional<Codec>) FieldUtils.readField(wbm, "codec", true);
Optional<Codec> codec = wbm.getCodec();
if (compress) {
Assertions.assertTrue(codec.isPresent());
} else {
@@ -620,6 +621,58 @@ public void spillByOwnWithSparkTaskMemoryManagerTest() {
assertEquals(2, fakedTaskMemoryManager.getSpilledCnt());
}

/**
* When {@code allocatedBytes} is already at the configured cap and spill does not release memory
* (e.g. remote push never completes), {@link WriteBufferManager} should time out waiting instead
* of blocking forever.
*/
@Test
public void requestMemoryThrowsWhenAllocatedOverCapAndNotReleased() throws Exception {
SparkConf conf = getConf();
conf.set("spark.executor.memory", "20k");
conf.set(RssSparkConfig.RSS_WRITER_MAX_ALLOCATED_MEMORY_RATIO.key(), "0.5");
conf.set(RssSparkConfig.RSS_WRITER_MAX_ALLOCATED_WAIT_TIMEOUT_MS.key(), "1");
conf.set(RssSparkConfig.RSS_WRITER_REQUIRE_MEMORY_INTERVAL.key(), "1");

TaskMemoryManager mockTaskMemoryManager = mock(TaskMemoryManager.class);
BufferManagerOptions bufferOptions = new BufferManagerOptions(conf);
long cap = bufferOptions.getMaxAllocatedBytesLimit();
Assertions.assertTrue(
cap > 0 && cap < Long.MAX_VALUE, "max allocated cap must be enabled and finite");

Function<List<ShuffleBlockInfo>, List<CompletableFuture<Long>>> spillWithoutReleasingMemory =
blocks -> Collections.singletonList(CompletableFuture.completedFuture(0L));

WriteBufferManager wbm =
new WriteBufferManager(
0,
"taskId_maxAllocatedCapTimeoutTest",
0,
bufferOptions,
new KryoSerializer(conf),
Maps.newHashMap(),
mockTaskMemoryManager,
new ShuffleWriteMetrics(),
RssSparkConfig.toRssConf(conf),
spillWithoutReleasingMemory,
0);

wbm.setAllocatedBytes(cap);
wbm.setUsedBytes(cap - 20L);

RssException thrown =
Assertions.assertThrows(
RssException.class,
() -> wbm.addRecord(0, "Key", "Value"),
"expected timeout while waiting for allocated memory to drop below cap");
Assertions.assertTrue(
thrown.getMessage().contains("Waiting timeout"),
"message should describe wait timeout: " + thrown.getMessage());
Assertions.assertTrue(
thrown.getMessage().contains("maxAllocatedBytesLimit"),
"message should include limit context: " + thrown.getMessage());
}

@Test
public void addFirstRecordWithLargeSizeTest() {
SparkConf conf = getConf();
ThreadUtils.java
@@ -40,6 +40,8 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.apache.uniffle.common.exception.RssException;

public class ThreadUtils {
private static final Logger LOGGER = LoggerFactory.getLogger(ThreadUtils.class);
private static final ThreadMXBean THREAD_BEAN = ManagementFactory.getThreadMXBean();
@@ -224,4 +226,12 @@ public Thread newThread(final Runnable runnable) {
return t;
}
}

public static void sleep(long millis, String errorMessage) {
try {
Thread.sleep(millis);
} catch (InterruptedException e) {
// Restore the interrupt flag so callers can still observe the interruption.
Thread.currentThread().interrupt();
throw new RssException(errorMessage, e);
}
}
}