Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,19 @@

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.test.AutomatedTestBase;

public abstract class FedWorkerBase {
protected static final Log LOG = LogFactory.getLog(FedWorkerBase.class.getName());

/** Upper bound (ms) for {@link #awaitCompressed(long)} polling against async worker-side compression. */
protected static final int COMPRESS_TIMEOUT_MS = 10_000;

/** Poll interval used by {@link #awaitCompressed(long)} between successive reads. */
private static final int COMPRESS_POLL_INTERVAL_MS = 25;

private final InetSocketAddress addr;
public final int port;

Expand Down Expand Up @@ -70,6 +77,38 @@ public MatrixBlock getMatrixBlock(long id) {
return FederatedTestUtils.getMatrixBlock(id, addr);
}

/**
* Poll the federated worker until the matrix at {@code id} is observed as a
* {@link CompressedMatrixBlock}, or {@link #COMPRESS_TIMEOUT_MS} elapses.
*
* <p>Federated workers compress asynchronously after a PUT/READ_VAR (see
* {@code CompressedMatrixBlockFactory.compressAsync}), so a {@code getMatrixBlock} fired right
* after the operation can race against the in-flight compression and return the uncompressed
* block. Tests that need to observe the compressed form should poll instead of sleeping a fixed
* amount.
*
* <p>On timeout this returns the most recent (uncompressed) read so the caller can produce a
* meaningful assertion failure naming the variable.
*
* @param id federated variable id
* @return the matrix block, compressed if compression finished in time, otherwise the latest read
*/
public MatrixBlock awaitCompressed(long id) {
final long deadline = System.currentTimeMillis() + COMPRESS_TIMEOUT_MS;
MatrixBlock mb = getMatrixBlock(id);
while(!(mb instanceof CompressedMatrixBlock) && System.currentTimeMillis() < deadline) {
try {
Thread.sleep(COMPRESS_POLL_INTERVAL_MS);
}
catch(InterruptedException ie) {
Thread.currentThread().interrupt();
fail("Interrupted while waiting for federated compression of id=" + id);
}
mb = getMatrixBlock(id);
}
return mb;
}

public long matrixMult(long idLeft, long idRight) {
return FederatedTestUtils.exec_MM(idLeft, idRight, addr);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,14 +65,16 @@ public void verifySameOrAlsoCompressedAsLocalCompress() {
// local
final MatrixBlock mbcLocal = CompressedMatrixBlockFactory.compress(mb).getLeft();

// federated
// federated. Compression on the worker is async; poll only when we expect compression to
// match the local result, otherwise a single read is enough.
final long id = putMatrixBlock(mb);
// give the federated site time to compress async.
FederatedTestUtils.wait(1000);
final MatrixBlock mbr = getMatrixBlock(id);
final MatrixBlock mbr = (mbcLocal instanceof CompressedMatrixBlock)
? awaitCompressed(id)
: getMatrixBlock(id);

if(mbcLocal instanceof CompressedMatrixBlock && !(mbr instanceof CompressedMatrixBlock))
fail("Invalid result, the federated site did not compress the matrix block");
fail("Invalid result, the federated site did not compress the matrix block within "
+ COMPRESS_TIMEOUT_MS + "ms");

TestUtils.compareMatricesBitAvgDistance(mbcLocal, mbr, 0, 0,
"Not equivalent matrix block returned from federated site");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,19 +88,16 @@ public void verifySameOrAlsoCompressedAsLocalCompress() {
for(int i = 0; i < 9; i++) // chain left side compressed multiplications with idr.
ide = matrixMult(ide, idr);

// give the federated site time to compress async (it should already be done, but just to be safe).
FederatedTestUtils.wait(1000);

// Get back the matrix block stored behind mbr that should be compressed now.
final MatrixBlock mbr_compressed = getMatrixBlock(idr);
// Workload-driven compression runs async on the worker; poll instead of sleeping a fixed
// amount so a slow runner doesn't observe the still-uncompressed block.
final MatrixBlock mbr_compressed = awaitCompressed(idr);

if(!(mbr_compressed instanceof CompressedMatrixBlock))
fail("Invalid result, the federated site did not compress the matrix block based on workload");
fail("Invalid result, the federated site did not compress the matrix block based on workload within "
+ COMPRESS_TIMEOUT_MS + "ms");

TestUtils.compareMatricesBitAvgDistance(mbcLocal, mbr_compressed, 0, 0,
"Not equivalent matrix block returned from federated site");
}



}
Original file line number Diff line number Diff line change
Expand Up @@ -65,15 +65,14 @@ public FedWorkerReadMatrixCompress(int port, String path) {
public void verifyRead() {
MatrixBlock expected = readCSV();
Long id = readMatrix(path);
// give the federated site time to compress async.
FederatedTestUtils.wait(1000);
MatrixBlock actual = getMatrixBlock(id);
// Compression happens async on the worker; poll instead of sleeping a fixed amount.
MatrixBlock actual = awaitCompressed(id);
if(actual instanceof CompressedMatrixBlock){
TestUtils.compareMatricesBitAvgDistance(expected, actual, 0, 0,
"Not equivalent matrix block read from federated site");
}
else
fail("Did not compress the matrix input");
fail("Did not compress the matrix input within " + COMPRESS_TIMEOUT_MS + "ms");
}

protected MatrixBlock readCSV() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -190,13 +190,4 @@ private static void exec(long id, String inst, InetSocketAddress addr, int timeo
fail("Failed to get response from put Matrix Block");
}
}

protected static void wait(int ms) {
try {
Thread.sleep(ms);
}
catch(Exception e) {
fail("Failed to wait");
}
}
}
Loading