Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import java.io.FileOutputStream;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.nio.ByteBuffer;
import java.nio.file.AtomicMoveNotSupportedException;
import java.nio.file.Files;
import java.nio.file.StandardCopyOption;
Expand Down Expand Up @@ -102,6 +103,16 @@ public void onMessage(WebSocket conn, String message) {
}
}

@Override
public void onMessage(WebSocket conn, ByteBuffer message) {
try {
AgentWsEnvelope.BinaryFileChunk chunk = AgentWsEnvelope.fromBinaryFileChunk(message);
handleFileChunk(conn, chunk.fileId(), chunk.chunkIndex(), chunk.payload());
} catch (Exception e) {
logger.warn("Failed handling startup WS binary message: {}", e.getMessage(), e);
}
}

@Override
public void onClose(WebSocket conn, int code, String reason, boolean remote) {
logger.info("Startup WS connection closed code={} reason={} remote={}", code, reason, remote);
Expand Down Expand Up @@ -228,28 +239,38 @@ private synchronized void handleFileOffer(WebSocket conn, AgentWsEnvelope envelo
}

private synchronized void handleFileChunk(WebSocket conn, AgentWsEnvelope envelope) {
if (currentFileStream == null || currentFileId == null || !currentFileId.equals(envelope.getFileId())) {
sendFileAck(conn, envelope.getFileId(), envelope.getChunkIndex(), AckStatus.failed, "file_offer_not_found");
return;
}

try {
byte[] payload = envelope.getChunkData() == null || envelope.getChunkData().isEmpty()
? new byte[0]
: Base64.getDecoder().decode(envelope.getChunkData());
handleFileChunk(conn, envelope.getFileId(), envelope.getChunkIndex(), payload);
} catch (Exception e) {
closeCurrentFileQuietly();
sendFileAck(conn, envelope.getFileId(), envelope.getChunkIndex(), AckStatus.failed, e.getMessage());
harnessJarFuture.completeExceptionally(e);
}
}

private synchronized void handleFileChunk(WebSocket conn, String fileId, Integer chunkIndex, byte[] payload) {
if (currentFileStream == null || currentFileId == null || !currentFileId.equals(fileId)) {
sendFileAck(conn, fileId, chunkIndex, AckStatus.failed, "file_offer_not_found");
return;
}

try {
currentFileStream.write(payload);
receivedBytes += payload.length;
receivedChunks++;
sendFileAck(conn, envelope.getFileId(), envelope.getChunkIndex(), AckStatus.chunk_received, null);
sendFileAck(conn, fileId, chunkIndex, AckStatus.chunk_received, null);

if (expectedBytes > 0 && receivedBytes >= expectedBytes) {
File completedFile = finalizeHarnessJar();
sendFileAck(conn, envelope.getFileId(), envelope.getChunkIndex(), AckStatus.complete, null);
sendFileAck(conn, fileId, chunkIndex, AckStatus.complete, null);
harnessJarFuture.complete(completedFile);
}
} catch (Exception e) {
closeCurrentFileQuietly();
sendFileAck(conn, envelope.getFileId(), envelope.getChunkIndex(), AckStatus.failed, e.getMessage());
sendFileAck(conn, fileId, chunkIndex, AckStatus.failed, e.getMessage());
harnessJarFuture.completeExceptionally(e);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import java.io.IOException;
import java.io.OutputStream;
import java.net.InetSocketAddress;
import java.nio.ByteBuffer;
import java.nio.file.AtomicMoveNotSupportedException;
import java.nio.file.Files;
import java.nio.file.Path;
Expand Down Expand Up @@ -105,7 +106,7 @@ public void onMessage(WebSocket connection, String message) {
case command -> handleCommand(connection, envelope);
case ping -> handlePing(connection, envelope);
case job_config -> handleJobConfig(envelope);
case file_offer -> handleFileOffer(envelope);
case file_offer -> handleFileOffer(connection, envelope);
case file_chunk -> handleFileChunk(connection, envelope);
case ack -> LOG.info(new ObjectMessage(Map.of("Message", "Received WS ack: " + envelope.getAckForType())));
case close -> {
Expand All @@ -119,6 +120,16 @@ public void onMessage(WebSocket connection, String message) {
}
}

@Override
public void onMessage(WebSocket connection, ByteBuffer message) {
try {
AgentWsEnvelope.BinaryFileChunk chunk = AgentWsEnvelope.fromBinaryFileChunk(message);
handleFileChunk(connection, chunk.fileId(), chunk.chunkIndex(), chunk.payload());
} catch (IOException e) {
LOG.warn(new ObjectMessage(Map.of("Message", "Failed parsing WS binary file chunk: " + e.getMessage())));
}
}

private void handleCommand(WebSocket connection, AgentWsEnvelope envelope) {
String commandId = envelope.getCommandId();
String command = envelope.getCommand();
Expand Down Expand Up @@ -182,7 +193,7 @@ private void handleJobConfig(AgentWsEnvelope envelope) {
}
}

private void handleFileOffer(AgentWsEnvelope envelope) {
private void handleFileOffer(WebSocket connection, AgentWsEnvelope envelope) {
String fileId = envelope.getFileId();
if (fileId == null || envelope.getFileName() == null) {
return;
Expand Down Expand Up @@ -212,48 +223,59 @@ private void handleFileOffer(AgentWsEnvelope envelope) {
new FileOutputStream(tempFile)
);
incomingFiles.put(fileId, state);
sendFileAck(connection, fileId, 0, AckStatus.ok, null);
} catch (Exception e) {
LOG.warn(new ObjectMessage(Map.of("Message", "Failed handling file_offer " + envelope.getFileName() + ": " + e.getMessage())));
sendFileAck(connection, fileId, 0, AckStatus.failed, e.getMessage());
}
}

private void handleFileChunk(WebSocket connection, AgentWsEnvelope envelope) {
String fileId = envelope.getFileId();
byte[] payload;
try {
payload = envelope.getChunkData() == null || envelope.getChunkData().isEmpty()
? new byte[0]
: Base64.getDecoder().decode(envelope.getChunkData());
} catch (Exception e) {
sendFileAck(connection, fileId, envelope.getChunkIndex(), AckStatus.failed, e.getMessage());
return;
}
handleFileChunk(connection, fileId, envelope.getChunkIndex(), payload);
}

private void handleFileChunk(WebSocket connection, String fileId, Integer chunkIndex, byte[] payload) {
IncomingFileState state = fileId != null ? incomingFiles.get(fileId) : null;
if (state == null) {
sendFileAck(connection, fileId, envelope.getChunkIndex(), AckStatus.failed, "file_offer_not_found");
sendFileAck(connection, fileId, chunkIndex, AckStatus.failed, "file_offer_not_found");
return;
}

try {
byte[] payload = envelope.getChunkData() == null || envelope.getChunkData().isEmpty()
? new byte[0]
: Base64.getDecoder().decode(envelope.getChunkData());

if (state.totalChunks < 0 && payload.length == 0) {
finalizeFile(state);
incomingFiles.remove(fileId);
sendFileAck(connection, fileId, envelope.getChunkIndex(), AckStatus.complete, null);
markFileComplete(connection, fileId, envelope.getChunkIndex());
sendFileAck(connection, fileId, chunkIndex, AckStatus.complete, null);
markFileComplete(connection, fileId, chunkIndex);
return;
}

state.outputStream.write(payload);
state.receivedBytes += payload.length;
state.receivedChunks++;

sendFileAck(connection, fileId, envelope.getChunkIndex(), AckStatus.chunk_received, null);
sendFileAck(connection, fileId, chunkIndex, AckStatus.chunk_received, null);

if (state.totalChunks >= 0 && state.receivedChunks >= state.totalChunks) {
finalizeFile(state);
incomingFiles.remove(fileId);
sendFileAck(connection, fileId, envelope.getChunkIndex(), AckStatus.complete, null);
markFileComplete(connection, fileId, envelope.getChunkIndex());
sendFileAck(connection, fileId, chunkIndex, AckStatus.complete, null);
markFileComplete(connection, fileId, chunkIndex);
}
} catch (Exception e) {
incomingFiles.remove(fileId);
state.closeQuietly();
sendFileAck(connection, fileId, envelope.getChunkIndex(), AckStatus.failed, e.getMessage());
sendFileAck(connection, fileId, chunkIndex, AckStatus.failed, e.getMessage());
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
import com.intuit.tank.vm.vmManager.models.CloudVmStatus;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;

@JsonIgnoreProperties(ignoreUnknown = true)
@JsonInclude(JsonInclude.Include.NON_NULL)
Expand All @@ -17,6 +19,7 @@ public class AgentWsEnvelope {
.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);

public static final int PROTOCOL_VERSION = 1;
private static final int BINARY_FILE_CHUNK_MAGIC = 0x54574331;

public enum Type {
hello, command, ack, ping, pong, close,
Expand Down Expand Up @@ -233,6 +236,55 @@ public static AgentWsEnvelope fromJson(String json) throws IOException {
return MAPPER.readValue(json, AgentWsEnvelope.class);
}

public static ByteBuffer binaryFileChunk(String fileId, int chunkIndex, byte[] bytes, int offset, int len)
throws IOException {
if (fileId == null || fileId.isBlank()) {
throw new IOException("fileId is required for binary file chunk");
}
if (bytes == null) {
bytes = new byte[0];
}
if (offset < 0 || len < 0 || offset + len > bytes.length) {
throw new IOException("Invalid binary file chunk range");
}
byte[] fileIdBytes = fileId.getBytes(StandardCharsets.UTF_8);
if (fileIdBytes.length > 0xFFFF) {
throw new IOException("fileId too long for binary file chunk");
}
ByteBuffer buffer = ByteBuffer.allocate(Integer.BYTES + Integer.BYTES + Short.BYTES + fileIdBytes.length + len);
buffer.putInt(BINARY_FILE_CHUNK_MAGIC);
buffer.putInt(chunkIndex);
buffer.putShort((short) fileIdBytes.length);
buffer.put(fileIdBytes);
buffer.put(bytes, offset, len);
buffer.flip();
return buffer;
}

public static BinaryFileChunk fromBinaryFileChunk(ByteBuffer message) throws IOException {
ByteBuffer buffer = message.slice();
if (buffer.remaining() < Integer.BYTES + Integer.BYTES + Short.BYTES) {
throw new IOException("Invalid binary file chunk header");
}
int magic = buffer.getInt();
if (magic != BINARY_FILE_CHUNK_MAGIC) {
throw new IOException("Unsupported binary file chunk frame");
}
int chunkIndex = buffer.getInt();
int fileIdLength = Short.toUnsignedInt(buffer.getShort());
if (fileIdLength <= 0 || fileIdLength > buffer.remaining()) {
throw new IOException("Invalid binary file chunk fileId length");
}
byte[] fileIdBytes = new byte[fileIdLength];
buffer.get(fileIdBytes);
byte[] payload = new byte[buffer.remaining()];
buffer.get(payload);
return new BinaryFileChunk(new String(fileIdBytes, StandardCharsets.UTF_8), chunkIndex, payload);
}

public record BinaryFileChunk(String fileId, int chunkIndex, byte[] payload) {
}

// Factory methods for common frame types

public static AgentWsEnvelope hello(String instanceId, String jobId, String agentSessionId, String lastAppliedCommandId) {
Expand Down
Loading