Skip to content

Commit d74a9bb

Browse files
committed
Add OOC Memory Tracking
1 parent 6fe4fa5 commit d74a9bb

11 files changed

Lines changed: 1849 additions & 2 deletions

File tree

src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -672,8 +672,7 @@ else if(err instanceof Exception)
672672
List<OOCStream.QueueCallback<IndexedMatrixValue>> outList = new ArrayList<>(r.size());
673673
for(int j = 0; j < r.size(); j++) {
674674
if(explicitCaching[j]) {
675-
// Early forget item from cache
676-
outList.add(new OOCStream.SimpleQueueCallback<>(r.get(j).get(), null));
675+
outList.add(r.get(j).keepOpen());
677676
}
678677
else {
679678
outList.add(r.get(j).keepOpen());

src/main/java/org/apache/sysds/runtime/ooc/cache/OOCCacheManager.java

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import org.apache.sysds.runtime.instructions.ooc.TeeOOCInstruction;
2727
import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue;
2828
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
29+
import org.apache.sysds.runtime.ooc.memory.InMemoryQueueCallback;
2930
import org.apache.sysds.runtime.ooc.stats.OOCEventLog;
3031
import org.apache.sysds.utils.Statistics;
3132

@@ -130,6 +131,10 @@ public static void forget(long streamId, int blockId) {
130131
getCache().forget(key);
131132
}
132133

134+
/**
 * Forgets the block identified by the given key by delegating to
 * {@code getCache().forget(key)}.
 *
 * @param key key of the block to forget
 */
public static void forget(BlockKey key) {
135+
getCache().forget(key);
136+
}
137+
133138
/**
134139
* Store a block in the OOC cache (serialize once)
135140
*/
@@ -195,6 +200,15 @@ public static CompletableFuture<OOCStream.QueueCallback<IndexedMatrixValue>> req
195200
return getCache().request(key).thenApply(e -> toCallback(e, key, null));
196201
}
197202

203+
public static OOCStream.QueueCallback<IndexedMatrixValue> tryRequestBlock(long streamId, long blockId) {
204+
return tryRequestBlock(new BlockKey(streamId, (int) blockId));
205+
}
206+
207+
public static OOCStream.QueueCallback<IndexedMatrixValue> tryRequestBlock(BlockKey key) {
208+
BlockEntry entry = getCache().tryRequest(key);
209+
return entry == null ? null : toCallback(entry, key, null);
210+
}
211+
198212
public static CompletableFuture<List<OOCStream.QueueCallback<IndexedMatrixValue>>> requestManyBlocks(List<BlockKey> keys) {
199213
return getCache().request(keys).thenApply(
200214
l -> {
@@ -245,6 +259,10 @@ public static boolean canClaimMemory() {
245259
return getCache().isWithinLimits() && OOCInstruction.getComputeInFlight() <= OOCInstruction.getComputeBackpressureThreshold();
246260
}
247261

262+
/**
 * Hands the given in-memory queue callback over to the cache under the given key
 * by delegating to {@code getCache().handover(key, callback)}.
 *
 * @param key      key under which the handover is requested
 * @param callback in-memory queue callback to hand over
 * @return the handle returned by the cache for this handover
 */
public static OOCCacheScheduler.HandoverHandle handover(BlockKey key, InMemoryQueueCallback callback) {
263+
return getCache().handover(key, callback);
264+
}
265+
248266
/**
 * Pins the given entry by delegating to {@code getCache().pin(entry)}.
 *
 * @param entry the cache entry to pin
 */
private static void pin(BlockEntry entry) {
249267
getCache().pin(entry);
250268
}

src/main/java/org/apache/sysds/runtime/ooc/cache/OOCCacheScheduler.java

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,10 @@
1919

2020
package org.apache.sysds.runtime.ooc.cache;
2121

22+
import org.apache.sysds.runtime.instructions.ooc.OOCStream;
23+
import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue;
24+
import org.apache.sysds.runtime.ooc.memory.InMemoryQueueCallback;
25+
2226
import java.util.Collection;
2327
import java.util.List;
2428
import java.util.concurrent.CompletableFuture;
@@ -32,6 +36,17 @@ public interface OOCCacheScheduler {
3236
*/
3337
CompletableFuture<BlockEntry> request(BlockKey key);
3438

39+
/**
40+
* Tries to request a single block from the cache.
41+
* Immediately returns the entry if present, otherwise null without scheduling reads.
42+
* @param key the requested key associated to the block
43+
* @return the available BlockEntry or null
44+
*/
45+
default BlockEntry tryRequest(BlockKey key) {
46+
List<BlockEntry> out = tryRequest(List.of(key));
47+
return out == null || out.isEmpty() ? null : out.get(0);
48+
}
49+
3550
/**
3651
* Requests a list of blocks from the cache that must be available at the same time.
3752
* @param keys the requested keys associated to the block
@@ -81,6 +96,15 @@ public interface OOCCacheScheduler {
8196
*/
8297
BlockEntry putAndPin(BlockKey key, Object data, long size);
8398

99+
/**
 * Handle returned by {@code handover(BlockKey, InMemoryQueueCallback)} that lets the
 * caller observe or take back a handover.
 * NOTE(review): the per-method notes below describe what the LRU scheduler's
 * implementation does; confirm they are intended as the general contract.
 */
interface HandoverHandle {
100+
/** @return the key the handover was requested for */
BlockKey getKey();
101+
/** @return true once the handover has been committed */
boolean isCommitted();
102+
/**
 * @return future reporting the outcome of the handover; the LRU scheduler completes
 *         it with true on commit and with false on cancellation or reclaim
 */
CompletableFuture<Boolean> getCompletionFuture();
103+
/**
 * Takes the callback back from the scheduler.
 * @return the callback, or null if it is no longer available (in the LRU scheduler:
 *         when the handover is committed or currently being committed)
 */
OOCStream.QueueCallback<IndexedMatrixValue> reclaim();
104+
}
105+
106+
/**
 * Offers the given in-memory queue callback to the cache under the given key.
 * NOTE(review): in the LRU scheduler the handover is committed immediately if the
 * cache can accept the callback's managed bytes and is queued otherwise; confirm
 * whether this is required of all implementations.
 * @param key the key the handover is requested for
 * @param callback the in-memory queue callback to hand over
 * @return a handle for observing or reclaiming the handover
 */
HandoverHandle handover(BlockKey key, InMemoryQueueCallback callback);
107+
84108
/**
85109
* Places a new source-backed block in the cache and registers the location with the IO handler. The entry is
86110
* treated as backed by disk, so eviction does not schedule spill writes.

src/main/java/org/apache/sysds/runtime/ooc/cache/OOCLRUCacheScheduler.java

Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,9 @@
2323
import org.apache.commons.logging.Log;
2424
import org.apache.commons.logging.LogFactory;
2525
import org.apache.sysds.api.DMLScript;
26+
import org.apache.sysds.runtime.instructions.ooc.OOCStream;
27+
import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue;
28+
import org.apache.sysds.runtime.ooc.memory.InMemoryQueueCallback;
2629
import org.apache.sysds.runtime.ooc.stats.OOCEventLog;
2730
import org.apache.sysds.utils.Statistics;
2831
import scala.Tuple2;
@@ -49,6 +52,7 @@ public class OOCLRUCacheScheduler implements OOCCacheScheduler {
4952
private final HashMap<BlockKey, BlockEntry> _evictionCache;
5053
private final DeferredReadQueue _deferredReadRequests;
5154
private final Deque<DeferredReadRequest> _processingReadRequests;
55+
private final Deque<PendingHandover> _pendingHandovers;
5256
private final HashMap<BlockKey, BlockReadState> _blockReads;
5357
private volatile long _hardLimit;
5458
private long _evictionLimit;
@@ -74,6 +78,7 @@ public OOCLRUCacheScheduler(OOCIOHandler ioHandler, long evictionLimit, long har
7478
this._evictionCache = new HashMap<>();
7579
this._deferredReadRequests = new DeferredReadQueue();
7680
this._processingReadRequests = new ArrayDeque<>();
81+
this._pendingHandovers = new ArrayDeque<>();
7782
this._blockReads = new HashMap<>();
7883
this._hardLimit = hardLimit;
7984
this._evictionLimit = evictionLimit;
@@ -282,6 +287,25 @@ public BlockEntry putAndPin(BlockKey key, Object data, long size) {
282287
return put(key, data, size, true, null);
283288
}
284289

290+
@Override
291+
public HandoverHandle handover(BlockKey key, InMemoryQueueCallback callback) {
292+
if(!this._running)
293+
throw new IllegalStateException("Cache scheduler has been shut down.");
294+
PendingHandover handover = new PendingHandover(key, callback);
295+
boolean immediateCommit = false;
296+
synchronized(this) {
297+
if(canAcceptHandoverLocked(callback.getManagedBytes()))
298+
immediateCommit = true;
299+
else
300+
_pendingHandovers.addLast(handover);
301+
}
302+
if(immediateCommit) {
303+
if(commitHandover(handover))
304+
onCacheSizeChanged(true);
305+
}
306+
return handover;
307+
}
308+
285309
@Override
286310
public void putSourceBacked(BlockKey key, Object data, long size, OOCIOHandler.SourceBlockDescriptor descriptor) {
287311
put(key, data, size, false, descriptor);
@@ -487,6 +511,14 @@ public synchronized void shutdown() {
487511
_cache.clear();
488512
_evictionCache.clear();
489513
_processingReadRequests.clear();
514+
while(!_pendingHandovers.isEmpty()) {
515+
PendingHandover pending = _pendingHandovers.pollFirst();
516+
if(pending == null)
517+
continue;
518+
OOCStream.QueueCallback<IndexedMatrixValue> callback = pending.reclaim();
519+
if(callback != null)
520+
callback.close();
521+
}
490522
_deferredReadRequests.clear();
491523
_deferredReadCountHint = 0;
492524
_blockReads.clear();
@@ -555,6 +587,9 @@ private void onCacheSizeChangedInternal(boolean incr) {
555587
onCacheSizeIncremented();
556588
else
557589
while(onCacheSizeDecremented()) {}
590+
while(processPendingHandovers()) {
591+
onCacheSizeIncremented();
592+
}
558593
if(DMLScript.OOC_LOG_EVENTS)
559594
OOCEventLog.onCacheSizeChangedEvent(_callerId, System.nanoTime(), _cacheSize, _bytesUpForEviction,
560595
_pinnedBytes, _readingReservedBytes);
@@ -721,6 +756,32 @@ private long getEvictionPressure() {
721756
return _cacheSize + _readBuffer - _bytesUpForEviction;
722757
}
723758

759+
private boolean processPendingHandovers() {
760+
List<PendingHandover> committed = new ArrayList<>();
761+
synchronized(this) {
762+
while(!_pendingHandovers.isEmpty()) {
763+
PendingHandover pending = _pendingHandovers.peekFirst();
764+
if(pending == null)
765+
break;
766+
if(pending.isCancelled()) {
767+
_pendingHandovers.pollFirst();
768+
continue;
769+
}
770+
long bytes = pending.getManagedBytes();
771+
if(!canAcceptHandoverLocked(bytes))
772+
break;
773+
_pendingHandovers.pollFirst();
774+
committed.add(pending);
775+
}
776+
}
777+
boolean progress = false;
778+
for(PendingHandover pending : committed) {
779+
if(commitHandover(pending))
780+
progress = true;
781+
}
782+
return progress;
783+
}
784+
724785
private boolean onCacheSizeDecremented() {
725786
if(_cacheSize + 10000000 >= _hardLimit || _deferredReadCountHint == 0)
726787
return false;
@@ -1018,6 +1079,34 @@ private void registerWaiter(BlockKey key, DeferredReadRequest request, int index
10181079
state.waiters.add(new DeferredReadWaiter(request, index));
10191080
}
10201081

1082+
private boolean commitHandover(PendingHandover pending) {
1083+
InMemoryQueueCallback callback = pending.takeForCommit();
1084+
if(callback == null)
1085+
return false;
1086+
try {
1087+
IndexedMatrixValue value = callback.takeManagedResultForHandover();
1088+
long size = callback.getManagedBytes();
1089+
synchronized(this) {
1090+
BlockEntry entry = new BlockEntry(pending.getKey(), size, value);
1091+
_cache.put(pending.getKey(), entry);
1092+
_cacheSize += size;
1093+
}
1094+
callback.releaseManagedMemory();
1095+
callback.close();
1096+
pending.markCommitted();
1097+
return true;
1098+
}
1099+
catch(Throwable t) {
1100+
pending.markCancelled();
1101+
callback.close();
1102+
throw t;
1103+
}
1104+
}
1105+
1106+
private boolean canAcceptHandoverLocked(long bytes) {
1107+
return bytes >= 0 && _cacheSize + bytes <= _hardLimit;
1108+
}
1109+
10211110
private static class BlockReadState {
10221111
private double priority;
10231112
private final List<DeferredReadWaiter> waiters;
@@ -1037,4 +1126,76 @@ private DeferredReadWaiter(DeferredReadRequest request, int index) {
10371126
this.index = index;
10381127
}
10391128
}
1129+
1130+
private static class PendingHandover implements HandoverHandle {
1131+
private final BlockKey _key;
1132+
private final CompletableFuture<Boolean> _completionFuture;
1133+
private InMemoryQueueCallback _callback;
1134+
private boolean _committed;
1135+
private boolean _cancelled;
1136+
private boolean _committing;
1137+
1138+
private PendingHandover(BlockKey key, InMemoryQueueCallback callback) {
1139+
_key = key;
1140+
_completionFuture = new CompletableFuture<>();
1141+
_callback = callback;
1142+
}
1143+
1144+
@Override
1145+
public synchronized BlockKey getKey() {
1146+
return _key;
1147+
}
1148+
1149+
@Override
1150+
public synchronized boolean isCommitted() {
1151+
return _committed;
1152+
}
1153+
1154+
@Override
1155+
public synchronized CompletableFuture<Boolean> getCompletionFuture() {
1156+
return _completionFuture;
1157+
}
1158+
1159+
@Override
1160+
public synchronized OOCStream.QueueCallback<IndexedMatrixValue> reclaim() {
1161+
if(_committed || _committing)
1162+
return null;
1163+
_cancelled = true;
1164+
_completionFuture.complete(false);
1165+
OOCStream.QueueCallback<IndexedMatrixValue> callback = _callback;
1166+
_callback = null;
1167+
return callback;
1168+
}
1169+
1170+
private synchronized long getManagedBytes() {
1171+
return _callback == null ? 0 : _callback.getManagedBytes();
1172+
}
1173+
1174+
private synchronized boolean isCancelled() {
1175+
return _cancelled;
1176+
}
1177+
1178+
private synchronized InMemoryQueueCallback takeForCommit() {
1179+
if(_committed || _cancelled || _committing)
1180+
return null;
1181+
_committing = true;
1182+
InMemoryQueueCallback callback = _callback;
1183+
_callback = null;
1184+
return callback;
1185+
}
1186+
1187+
private synchronized void markCommitted() {
1188+
_committing = false;
1189+
_committed = true;
1190+
_completionFuture.complete(true);
1191+
}
1192+
1193+
private synchronized void markCancelled() {
1194+
if(_committed || _cancelled)
1195+
return;
1196+
_committing = false;
1197+
_cancelled = true;
1198+
_completionFuture.complete(false);
1199+
}
1200+
}
10401201
}

0 commit comments

Comments
 (0)