diff --git a/agent/agent_startup/src/main/java/com/intuit/tank/agent/StartupWebSocketServer.java b/agent/agent_startup/src/main/java/com/intuit/tank/agent/StartupWebSocketServer.java index 270db707e..4b95e4cca 100644 --- a/agent/agent_startup/src/main/java/com/intuit/tank/agent/StartupWebSocketServer.java +++ b/agent/agent_startup/src/main/java/com/intuit/tank/agent/StartupWebSocketServer.java @@ -12,6 +12,7 @@ import java.io.FileOutputStream; import java.io.IOException; import java.net.InetSocketAddress; +import java.nio.ByteBuffer; import java.nio.file.AtomicMoveNotSupportedException; import java.nio.file.Files; import java.nio.file.StandardCopyOption; @@ -102,6 +103,16 @@ public void onMessage(WebSocket conn, String message) { } } + @Override + public void onMessage(WebSocket conn, ByteBuffer message) { + try { + AgentWsEnvelope.BinaryFileChunk chunk = AgentWsEnvelope.fromBinaryFileChunk(message); + handleFileChunk(conn, chunk.fileId(), chunk.chunkIndex(), chunk.payload()); + } catch (Exception e) { + logger.warn("Failed handling startup WS binary message: {}", e.getMessage(), e); + } + } + @Override public void onClose(WebSocket conn, int code, String reason, boolean remote) { logger.info("Startup WS connection closed code={} reason={} remote={}", code, reason, remote); @@ -228,28 +239,38 @@ private synchronized void handleFileOffer(WebSocket conn, AgentWsEnvelope envelo } private synchronized void handleFileChunk(WebSocket conn, AgentWsEnvelope envelope) { - if (currentFileStream == null || currentFileId == null || !currentFileId.equals(envelope.getFileId())) { - sendFileAck(conn, envelope.getFileId(), envelope.getChunkIndex(), AckStatus.failed, "file_offer_not_found"); - return; - } - try { byte[] payload = envelope.getChunkData() == null || envelope.getChunkData().isEmpty() ? new byte[0] : Base64.getDecoder().decode(envelope.getChunkData()); + handleFileChunk(conn, envelope.getFileId(), envelope.getChunkIndex(), payload); + } catch (Exception e) { + closeCurrentFileQuietly(); + sendFileAck(conn, envelope.getFileId(), envelope.getChunkIndex(), AckStatus.failed, e.getMessage()); + harnessJarFuture.completeExceptionally(e); + } + } + + private synchronized void handleFileChunk(WebSocket conn, String fileId, Integer chunkIndex, byte[] payload) { + if (currentFileStream == null || currentFileId == null || !currentFileId.equals(fileId)) { + sendFileAck(conn, fileId, chunkIndex, AckStatus.failed, "file_offer_not_found"); + return; + } + + try { currentFileStream.write(payload); receivedBytes += payload.length; receivedChunks++; - sendFileAck(conn, envelope.getFileId(), envelope.getChunkIndex(), AckStatus.chunk_received, null); + sendFileAck(conn, fileId, chunkIndex, AckStatus.chunk_received, null); if (expectedBytes > 0 && receivedBytes >= expectedBytes) { File completedFile = finalizeHarnessJar(); - sendFileAck(conn, envelope.getFileId(), envelope.getChunkIndex(), AckStatus.complete, null); + sendFileAck(conn, fileId, chunkIndex, AckStatus.complete, null); harnessJarFuture.complete(completedFile); } } catch (Exception e) { closeCurrentFileQuietly(); - sendFileAck(conn, envelope.getFileId(), envelope.getChunkIndex(), AckStatus.failed, e.getMessage()); + sendFileAck(conn, fileId, chunkIndex, AckStatus.failed, e.getMessage()); harnessJarFuture.completeExceptionally(e); } } diff --git a/agent/apiharness/src/main/java/com/intuit/tank/harness/AgentCommandWebSocketServer.java b/agent/apiharness/src/main/java/com/intuit/tank/harness/AgentCommandWebSocketServer.java index 0c717b223..77f53491d 100644 --- a/agent/apiharness/src/main/java/com/intuit/tank/harness/AgentCommandWebSocketServer.java +++ b/agent/apiharness/src/main/java/com/intuit/tank/harness/AgentCommandWebSocketServer.java @@ -19,6 +19,7 @@ import java.io.IOException; import java.io.OutputStream; import java.net.InetSocketAddress; +import java.nio.ByteBuffer; import java.nio.file.AtomicMoveNotSupportedException; import java.nio.file.Files; import java.nio.file.Path; @@ -105,7 +106,7 @@ public void onMessage(WebSocket connection, String message) { case command -> handleCommand(connection, envelope); case ping -> handlePing(connection, envelope); case job_config -> handleJobConfig(envelope); - case file_offer -> handleFileOffer(envelope); + case file_offer -> handleFileOffer(connection, envelope); case file_chunk -> handleFileChunk(connection, envelope); case ack -> LOG.info(new ObjectMessage(Map.of("Message", "Received WS ack: " + envelope.getAckForType()))); case close -> { @@ -119,6 +120,16 @@ public void onMessage(WebSocket connection, String message) { } } + @Override + public void onMessage(WebSocket connection, ByteBuffer message) { + try { + AgentWsEnvelope.BinaryFileChunk chunk = AgentWsEnvelope.fromBinaryFileChunk(message); + handleFileChunk(connection, chunk.fileId(), chunk.chunkIndex(), chunk.payload()); + } catch (IOException e) { + LOG.warn(new ObjectMessage(Map.of("Message", "Failed parsing WS binary file chunk: " + e.getMessage()))); + } + } + private void handleCommand(WebSocket connection, AgentWsEnvelope envelope) { String commandId = envelope.getCommandId(); String command = envelope.getCommand(); @@ -182,7 +193,7 @@ private void handleJobConfig(AgentWsEnvelope envelope) { } } - private void handleFileOffer(AgentWsEnvelope envelope) { + private void handleFileOffer(WebSocket connection, AgentWsEnvelope envelope) { String fileId = envelope.getFileId(); if (fileId == null || envelope.getFileName() == null) { return; @@ -212,29 +223,40 @@ private void handleFileOffer(AgentWsEnvelope envelope) { new FileOutputStream(tempFile) ); incomingFiles.put(fileId, state); + sendFileAck(connection, fileId, 0, AckStatus.ok, null); } catch (Exception e) { LOG.warn(new ObjectMessage(Map.of("Message", "Failed handling file_offer " + envelope.getFileName() + ": " + e.getMessage()))); + sendFileAck(connection, fileId, 0, AckStatus.failed, e.getMessage()); } } private void handleFileChunk(WebSocket connection, AgentWsEnvelope envelope) { String fileId = envelope.getFileId(); + byte[] payload; + try { + payload = envelope.getChunkData() == null || envelope.getChunkData().isEmpty() + ? new byte[0] + : Base64.getDecoder().decode(envelope.getChunkData()); + } catch (Exception e) { + sendFileAck(connection, fileId, envelope.getChunkIndex(), AckStatus.failed, e.getMessage()); + return; + } + handleFileChunk(connection, fileId, envelope.getChunkIndex(), payload); + } + + private void handleFileChunk(WebSocket connection, String fileId, Integer chunkIndex, byte[] payload) { IncomingFileState state = fileId != null ? incomingFiles.get(fileId) : null; if (state == null) { - sendFileAck(connection, fileId, envelope.getChunkIndex(), AckStatus.failed, "file_offer_not_found"); + sendFileAck(connection, fileId, chunkIndex, AckStatus.failed, "file_offer_not_found"); return; } try { - byte[] payload = envelope.getChunkData() == null || envelope.getChunkData().isEmpty() - ? new byte[0] - : Base64.getDecoder().decode(envelope.getChunkData()); - if (state.totalChunks < 0 && payload.length == 0) { finalizeFile(state); incomingFiles.remove(fileId); - sendFileAck(connection, fileId, envelope.getChunkIndex(), AckStatus.complete, null); - markFileComplete(connection, fileId, envelope.getChunkIndex()); + sendFileAck(connection, fileId, chunkIndex, AckStatus.complete, null); + markFileComplete(connection, fileId, chunkIndex); return; } @@ -242,18 +264,18 @@ private void handleFileChunk(WebSocket connection, AgentWsEnvelope envelope) { state.receivedBytes += payload.length; state.receivedChunks++; - sendFileAck(connection, fileId, envelope.getChunkIndex(), AckStatus.chunk_received, null); + sendFileAck(connection, fileId, chunkIndex, AckStatus.chunk_received, null); if (state.totalChunks >= 0 && state.receivedChunks >= state.totalChunks) { finalizeFile(state); incomingFiles.remove(fileId); - sendFileAck(connection, fileId, envelope.getChunkIndex(), AckStatus.complete, null); - markFileComplete(connection, fileId, envelope.getChunkIndex()); + sendFileAck(connection, fileId, chunkIndex, AckStatus.complete, null); + markFileComplete(connection, fileId, chunkIndex); } } catch (Exception e) { incomingFiles.remove(fileId); state.closeQuietly(); - sendFileAck(connection, fileId, envelope.getChunkIndex(), AckStatus.failed, e.getMessage()); + sendFileAck(connection, fileId, chunkIndex, AckStatus.failed, e.getMessage()); } } diff --git a/api/src/main/java/com/intuit/tank/vm/agent/messages/AgentWsEnvelope.java b/api/src/main/java/com/intuit/tank/vm/agent/messages/AgentWsEnvelope.java index d355315f8..3744b88bd 100644 --- a/api/src/main/java/com/intuit/tank/vm/agent/messages/AgentWsEnvelope.java +++ b/api/src/main/java/com/intuit/tank/vm/agent/messages/AgentWsEnvelope.java @@ -8,6 +8,8 @@ import com.intuit.tank.vm.vmManager.models.CloudVmStatus; import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; @JsonIgnoreProperties(ignoreUnknown = true) @JsonInclude(JsonInclude.Include.NON_NULL) @@ -17,6 +19,7 @@ public class AgentWsEnvelope { .configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false); public static final int PROTOCOL_VERSION = 1; + private static final int BINARY_FILE_CHUNK_MAGIC = 0x54574331; public enum Type { hello, command, ack, ping, pong, close, @@ -233,6 +236,55 @@ public static AgentWsEnvelope fromJson(String json) throws IOException { return MAPPER.readValue(json, AgentWsEnvelope.class); } + public static ByteBuffer binaryFileChunk(String fileId, int chunkIndex, byte[] bytes, int offset, int len) + throws IOException { + if (fileId == null || fileId.isBlank()) { + throw new IOException("fileId is required for binary file chunk"); + } + if (bytes == null) { + bytes = new byte[0]; + } + if (offset < 0 || len < 0 || offset + len > bytes.length) { + throw new IOException("Invalid binary file chunk range"); + } + byte[] fileIdBytes = fileId.getBytes(StandardCharsets.UTF_8); + if (fileIdBytes.length > 0xFFFF) { + throw new IOException("fileId too long for binary file chunk"); + } + ByteBuffer buffer = ByteBuffer.allocate(Integer.BYTES + Integer.BYTES + Short.BYTES + fileIdBytes.length + len); + buffer.putInt(BINARY_FILE_CHUNK_MAGIC); + buffer.putInt(chunkIndex); + buffer.putShort((short) fileIdBytes.length); + buffer.put(fileIdBytes); + buffer.put(bytes, offset, len); + buffer.flip(); + return buffer; + } + + public static BinaryFileChunk fromBinaryFileChunk(ByteBuffer message) throws IOException { + ByteBuffer buffer = message.slice(); + if (buffer.remaining() < Integer.BYTES + Integer.BYTES + Short.BYTES) { + throw new IOException("Invalid binary file chunk header"); + } + int magic = buffer.getInt(); + if (magic != BINARY_FILE_CHUNK_MAGIC) { + throw new IOException("Unsupported binary file chunk frame"); + } + int chunkIndex = buffer.getInt(); + int fileIdLength = Short.toUnsignedInt(buffer.getShort()); + if (fileIdLength <= 0 || fileIdLength > buffer.remaining()) { + throw new IOException("Invalid binary file chunk fileId length"); + } + byte[] fileIdBytes = new byte[fileIdLength]; + buffer.get(fileIdBytes); + byte[] payload = new byte[buffer.remaining()]; + buffer.get(payload); + return new BinaryFileChunk(new String(fileIdBytes, StandardCharsets.UTF_8), chunkIndex, payload); + } + + public record BinaryFileChunk(String fileId, int chunkIndex, byte[] payload) { + } + // Factory methods for common frame types public static AgentWsEnvelope hello(String instanceId, String jobId, String agentSessionId, String lastAppliedCommandId) { diff --git a/tank_vmManager/src/main/java/com/intuit/tank/perfManager/workLoads/ControllerInitiatedAgentWsClient.java b/tank_vmManager/src/main/java/com/intuit/tank/perfManager/workLoads/ControllerInitiatedAgentWsClient.java index e7d25f7a2..22513b164 100644 --- a/tank_vmManager/src/main/java/com/intuit/tank/perfManager/workLoads/ControllerInitiatedAgentWsClient.java +++ b/tank_vmManager/src/main/java/com/intuit/tank/perfManager/workLoads/ControllerInitiatedAgentWsClient.java @@ -27,8 +27,6 @@ import java.nio.ByteBuffer; import java.nio.file.Files; import java.util.ArrayList; -import java.util.Arrays; -import java.util.Base64; import java.util.List; import java.util.Map; import java.util.Optional; @@ -48,8 +46,10 @@ public class ControllerInitiatedAgentWsClient implements AgentWsCommandSender { private static final String SETTINGS_FILE_NAME = "settings.xml"; private static final String SCRIPT_FILE_NAME = "script.xml"; private static final String LOCAL_CONTROLLER_ORIGIN = "http://localhost:8080"; - private static final int DEFAULT_CHUNK_BYTES = 524288; - private static final int CHUNK_ACK_WINDOW = 4; + private static final int DEFAULT_CHUNK_BYTES = + Math.max(1, Integer.getInteger("tank.ws.chunkBytes", 2 * 1024 * 1024)); + private static final int CHUNK_ACK_WINDOW = + Math.max(1, Integer.getInteger("tank.ws.chunkAckWindow", 32)); private static final long MAX_BOOTSTRAP_CONNECTION_MS = Long.getLong("tank.ws.bootstrap.maxConnectionMs", 30_000L); @@ -436,7 +436,7 @@ private boolean sendFile(SessionContext context, String instanceId, String jobId if (content == null || content.length == 0) { pendingAcks.remove(fileId); - sendChunk(context, instanceId, jobId, fileId, 0, new byte[0], 0, connectionDeadlineMs); + sendChunk(context, instanceId, fileId, 0, new byte[0], 0, 0, connectionDeadlineMs); return true; } @@ -478,23 +478,21 @@ private boolean sendFile(SessionContext context, String instanceId, String jobId return false; } int len = Math.min(chunkBytes, content.length - offset); - sendChunk(context, instanceId, jobId, fileId, chunkIndex, - Arrays.copyOfRange(content, offset, offset + len), len, connectionDeadlineMs); + sendChunk(context, instanceId, fileId, chunkIndex, content, offset, len, connectionDeadlineMs); chunkIndex++; } return true; } - private void sendChunk(SessionContext context, String instanceId, String jobId, String fileId, - int chunkIndex, byte[] bytes, int len, long connectionDeadlineMs) + private void sendChunk(SessionContext context, String instanceId, String fileId, + int chunkIndex, byte[] bytes, int offset, int len, long connectionDeadlineMs) throws IOException, InterruptedException { CompletableFuture gate = null; if ((chunkIndex + 1) % CHUNK_ACK_WINDOW == 0) { gate = new CompletableFuture<>(); pendingChunkAcks.put(instanceId, new PendingChunkAck(fileId, chunkIndex, gate)); } - String base64 = len == 0 ? "" : Base64.getEncoder().encodeToString(bytes); - sendEnvelope(context, AgentWsEnvelope.fileChunk(instanceId, jobId, fileId, chunkIndex, base64)); + sendBinaryChunk(context, AgentWsEnvelope.binaryFileChunk(fileId, chunkIndex, bytes, offset, len)); if (gate != null) { try { long ackTimeoutMs = 30_000L; @@ -524,13 +522,20 @@ private void sendChunk(SessionContext context, String instanceId, String jobId, } } + private void sendBinaryChunk(SessionContext context, ByteBuffer payload) { + synchronized (context.webSocket) { + context.webSocket.sendBinary(payload, true).join(); + } + } + private void sendEnvelope(SessionContext context, AgentWsEnvelope envelope) throws IOException { synchronized (context.webSocket) { context.webSocket.sendText(envelope.toJson(), true).join(); } } - private void handleText(String instanceId, String text, CompletableFuture helloFuture) { + private void handleText(String instanceId, WebSocket webSocket, String text, + CompletableFuture helloFuture) { try { AgentWsEnvelope envelope = AgentWsEnvelope.fromJson(text); if (envelope.getType() == null) { @@ -541,13 +546,10 @@ private void handleText(String instanceId, String text, CompletableFuture helloFuture.complete(envelope); case ack -> handleAck(envelope); - case file_ack -> handleFileAck(agentId, envelope); + case file_ack -> handleFileAck(agentId, webSocket, envelope); case status_update -> handleStatusUpdate(agentId, envelope); case pong -> LOG.debug(new ObjectMessage(Map.of("Message", "[WS] Pong from " + agentId))); - case close -> { - SessionContext ctx = sessions.get(agentId); - onClosed(agentId, ctx != null ? ctx.webSocket : null); - } + case close -> onClosed(agentId, webSocket); default -> { } } @@ -565,7 +567,24 @@ private void handleAck(AgentWsEnvelope envelope) { } } - private void handleFileAck(String instanceId, AgentWsEnvelope envelope) { + private void handleFileAck(String instanceId, WebSocket webSocket, AgentWsEnvelope envelope) { + SessionContext context = sessions.get(instanceId); + if (context == null) { + LOG.info(new ObjectMessage(Map.of("Message", + "[WS] Ignoring file_ack for non-active session " + instanceId))); + return; + } + if (webSocket == null) { + LOG.info(new ObjectMessage(Map.of("Message", + "[WS] Ignoring file_ack without WebSocket identity for " + instanceId))); + return; + } + if (context.webSocket != webSocket) { + LOG.info(new ObjectMessage(Map.of("Message", + "[WS] Ignoring stale file_ack from previous session for " + instanceId))); + return; + } + // Route offer-level acks (ok, resume, failed) to the pending offer future if (envelope.getFileId() != null && (envelope.getStatus() == AckStatus.ok || envelope.getStatus() == AckStatus.resume @@ -577,7 +596,6 @@ private void handleFileAck(String instanceId, AgentWsEnvelope envelope) { } } - SessionContext context = sessions.get(instanceId); if (envelope.getStatus() == AckStatus.all_files_complete) { fileTransferReady.put(instanceId, true); agentWsState.put(instanceId, "ready"); @@ -586,19 +604,15 @@ private void handleFileAck(String instanceId, AgentWsEnvelope envelope) { pendingChunkAcks.remove(instanceId, pending); pending.future.complete(null); } - if (context != null) { - context.transferCompleteFuture.complete(null); - } + context.transferCompleteFuture.complete(null); return; } if (envelope.getStatus() == AckStatus.complete) { - if (context != null) { - int completed = context.completedFiles.merge(instanceId, 1, Integer::sum); - agentTransferProgress.put(instanceId, completed + "/" + context.expectedFiles + " files"); - if (context.bootstrapTransfer && completed >= context.expectedFiles) { - context.transferCompleteFuture.complete(null); - } + int completed = context.completedFiles.merge(instanceId, 1, Integer::sum); + agentTransferProgress.put(instanceId, completed + "/" + context.expectedFiles + " files"); + if (context.bootstrapTransfer && completed >= context.expectedFiles) { + context.transferCompleteFuture.complete(null); } } @@ -614,9 +628,7 @@ private void handleFileAck(String instanceId, AgentWsEnvelope envelope) { pendingChunkAcks.remove(instanceId, pending); pending.future.completeExceptionally(new IOException(envelope.getError())); } - if (context != null) { - context.transferCompleteFuture.completeExceptionally(new IOException(envelope.getError())); - } + context.transferCompleteFuture.completeExceptionally(new IOException(envelope.getError())); } } @@ -640,14 +652,22 @@ private void onClosed(String instanceId, WebSocket webSocket) { "[WS] Ignoring close for non-active session " + instanceId))); return; } - if (webSocket != null && context.webSocket != webSocket) { + if (webSocket == null) { + LOG.info(new ObjectMessage(Map.of("Message", + "[WS] Ignoring close without WebSocket identity for " + instanceId))); + return; + } + if (context.webSocket != webSocket) { LOG.info(new ObjectMessage(Map.of("Message", "[WS] Ignoring stale close from previous session for " + instanceId))); return; } - if (sessions.remove(instanceId, context)) { - context.markClosed(); + if (!sessions.remove(instanceId, context)) { + LOG.info(new ObjectMessage(Map.of("Message", + "[WS] Ignoring close for replaced session " + instanceId))); + return; } + context.markClosed(); fileTransferReady.remove(instanceId); PendingChunkAck pending = pendingChunkAcks.remove(instanceId); if (pending != null) { @@ -677,7 +697,7 @@ public CompletionStage onText(WebSocket webSocket, CharSequence data, boolean if (last) { String fullMessage = messageBuffer.toString(); messageBuffer.setLength(0); - handleText(instanceId, fullMessage, helloFuture); + handleText(instanceId, webSocket, fullMessage, helloFuture); } webSocket.request(1); return null;