|
| 1 | +package com.databricks.jdbc.api.impl.arrow; |
| 2 | + |
| 3 | +import static com.databricks.jdbc.common.DatabricksJdbcConstants.ARROW_METADATA_KEY; |
| 4 | + |
| 5 | +import com.databricks.jdbc.common.CompressionCodec; |
| 6 | +import com.databricks.jdbc.common.util.DriverUtil; |
| 7 | +import com.databricks.jdbc.dbclient.IDatabricksHttpClient; |
| 8 | +import com.databricks.jdbc.dbclient.impl.common.StatementId; |
| 9 | +import com.databricks.jdbc.exception.DatabricksParsingException; |
| 10 | +import com.databricks.jdbc.exception.DatabricksSQLException; |
| 11 | +import com.databricks.jdbc.log.JdbcLogger; |
| 12 | +import com.databricks.jdbc.log.JdbcLoggerFactory; |
| 13 | +import com.databricks.jdbc.model.core.ExternalLink; |
| 14 | +import java.io.IOException; |
| 15 | +import java.io.InputStream; |
| 16 | +import java.nio.channels.ClosedByInterruptException; |
| 17 | +import java.time.Instant; |
| 18 | +import java.util.ArrayList; |
| 19 | +import java.util.List; |
| 20 | +import java.util.concurrent.CompletableFuture; |
| 21 | +import java.util.concurrent.ExecutionException; |
| 22 | +import java.util.concurrent.TimeUnit; |
| 23 | +import java.util.concurrent.TimeoutException; |
| 24 | +import java.util.stream.Collectors; |
| 25 | +import org.apache.arrow.memory.BufferAllocator; |
| 26 | +import org.apache.arrow.memory.RootAllocator; |
| 27 | +import org.apache.arrow.vector.ValueVector; |
| 28 | +import org.apache.arrow.vector.VectorSchemaRoot; |
| 29 | +import org.apache.arrow.vector.ipc.ArrowStreamReader; |
| 30 | +import org.apache.arrow.vector.util.TransferPair; |
| 31 | +import org.apache.commons.lang3.exception.ExceptionUtils; |
| 32 | + |
| 33 | +/** |
| 34 | + * An abstract class that represents a chunk of query result. |
| 35 | + * |
| 36 | + * <p>This class provides methods for downloading, processing, and releasing the data in the chunk. |
| 37 | + * It also manages the state of the chunk and provides access to the data as Arrow record batches. |
| 38 | + */ |
| 39 | +public abstract class AbstractArrowResultChunk { |
| 40 | + private static final JdbcLogger LOGGER = |
| 41 | + JdbcLoggerFactory.getLogger(AbstractArrowResultChunk.class); |
| 42 | + |
| 43 | + protected static final Integer SECONDS_BUFFER_FOR_EXPIRY = 60; |
| 44 | + protected static final long CHUNK_READY_TIMEOUT_SECONDS = 30; |
| 45 | + protected final long numRows; |
| 46 | + protected final long rowOffset; |
| 47 | + protected final long chunkIndex; |
| 48 | + protected final StatementId statementId; |
| 49 | + protected final BufferAllocator rootAllocator; |
| 50 | + |
| 51 | + /** |
| 52 | + * Future to track when the chunk becomes ready for consumption. This includes both the download |
| 53 | + * and processing phases. The state of the Future is updated by the {@link ChunkDownloadTask} and |
| 54 | + * indicates when the chunk's data is fully processed and available for use. |
| 55 | + */ |
| 56 | + protected final CompletableFuture<Void> chunkReadyFuture; |
| 57 | + |
| 58 | + protected final ArrowResultChunkStateMachine stateMachine; |
| 59 | + protected List<List<ValueVector>> recordBatchList; |
| 60 | + protected ExternalLink chunkLink; |
| 61 | + protected Instant expiryTime; |
| 62 | + protected String errorMessage; |
| 63 | + protected List<String> arrowMetadata; |
| 64 | + |
| 65 | + static final class ArrowData { |
| 66 | + private final List<List<ValueVector>> valueVectors; |
| 67 | + private final List<String> metadata; |
| 68 | + |
| 69 | + public ArrowData(List<List<ValueVector>> valueVectors, List<String> metadata) { |
| 70 | + this.valueVectors = valueVectors; |
| 71 | + this.metadata = metadata; |
| 72 | + } |
| 73 | + |
| 74 | + public List<List<ValueVector>> getValueVectors() { |
| 75 | + return valueVectors; |
| 76 | + } |
| 77 | + |
| 78 | + public List<String> getMetadata() { |
| 79 | + return metadata; |
| 80 | + } |
| 81 | + } |
| 82 | + |
| 83 | + protected AbstractArrowResultChunk( |
| 84 | + long numRows, |
| 85 | + long rowOffset, |
| 86 | + long chunkIndex, |
| 87 | + StatementId statementId, |
| 88 | + ChunkStatus initialStatus, |
| 89 | + ExternalLink chunkLink, |
| 90 | + Instant expiryTime) { |
| 91 | + this.numRows = numRows; |
| 92 | + this.rowOffset = rowOffset; |
| 93 | + this.chunkIndex = chunkIndex; |
| 94 | + this.statementId = statementId; |
| 95 | + this.rootAllocator = new RootAllocator(Integer.MAX_VALUE); |
| 96 | + this.chunkReadyFuture = new CompletableFuture<>(); |
| 97 | + this.chunkLink = chunkLink; |
| 98 | + this.expiryTime = expiryTime; |
| 99 | + this.stateMachine = new ArrowResultChunkStateMachine(initialStatus, chunkIndex, statementId); |
| 100 | + } |
| 101 | + |
| 102 | + /** |
| 103 | + * Returns the index of this chunk. |
| 104 | + * |
| 105 | + * @return chunk index |
| 106 | + */ |
| 107 | + public Long getChunkIndex() { |
| 108 | + return chunkIndex; |
| 109 | + } |
| 110 | + |
| 111 | + /** |
| 112 | + * Checks if the chunk link is invalid or expired. |
| 113 | + * |
| 114 | + * @return true if link is invalid, false otherwise |
| 115 | + */ |
| 116 | + public boolean isChunkLinkInvalid() { |
| 117 | + return getStatus() == ChunkStatus.PENDING |
| 118 | + || (!DriverUtil.isRunningAgainstFake() |
| 119 | + && expiryTime.minusSeconds(SECONDS_BUFFER_FOR_EXPIRY).isBefore(Instant.now())); |
| 120 | + } |
| 121 | + |
| 122 | + /** |
| 123 | + * Releases all resources associated with this chunk. |
| 124 | + * |
| 125 | + * @return true if chunk was released, false if it was already released |
| 126 | + */ |
| 127 | + public boolean releaseChunk() { |
| 128 | + if (getStatus() == ChunkStatus.CHUNK_RELEASED) { |
| 129 | + return false; |
| 130 | + } |
| 131 | + |
| 132 | + if (getStatus() == ChunkStatus.PROCESSING_SUCCEEDED) { |
| 133 | + logAllocatorStats("BeforeRelease"); |
| 134 | + purgeArrowData(this.recordBatchList); |
| 135 | + rootAllocator.close(); |
| 136 | + } |
| 137 | + setStatus(ChunkStatus.CHUNK_RELEASED); |
| 138 | + |
| 139 | + return true; |
| 140 | + } |
| 141 | + |
| 142 | + /** |
| 143 | + * Downloads and initializes data for this chunk using the provided HTTP client and compression |
| 144 | + * codec. |
| 145 | + * |
| 146 | + * @param httpClient the HTTP client to use for downloading |
| 147 | + * @param compressionCodec the compression codec to use for decompression |
| 148 | + * @throws DatabricksParsingException if there is an error parsing the data |
| 149 | + * @throws IOException if there is an error downloading or reading the data |
| 150 | + */ |
| 151 | + protected abstract void downloadData( |
| 152 | + IDatabricksHttpClient httpClient, CompressionCodec compressionCodec) |
| 153 | + throws DatabricksParsingException, IOException; |
| 154 | + |
| 155 | + /** Handles a failure during the download or processing of this chunk. */ |
| 156 | + protected abstract void handleFailure(Exception exception, ChunkStatus failedStatus) |
| 157 | + throws DatabricksParsingException; |
| 158 | + |
| 159 | + /** |
| 160 | + * Returns the number of record batches in the chunk. |
| 161 | + * |
| 162 | + * @return number of record batches |
| 163 | + */ |
| 164 | + protected int getRecordBatchCountInChunk() { |
| 165 | + return getStatus() == ChunkStatus.PROCESSING_SUCCEEDED ? recordBatchList.size() : 0; |
| 166 | + } |
| 167 | + |
| 168 | + /** |
| 169 | + * Returns the list of record batches, where each record batch is a list of value vectors. |
| 170 | + * |
| 171 | + * @return List of record batches |
| 172 | + */ |
| 173 | + protected List<List<ValueVector>> getRecordBatchList() { |
| 174 | + return recordBatchList; |
| 175 | + } |
| 176 | + |
| 177 | + /** |
| 178 | + * Returns the total number of rows in the chunk. |
| 179 | + * |
| 180 | + * @return number of rows |
| 181 | + */ |
| 182 | + protected long getNumRows() { |
| 183 | + return numRows; |
| 184 | + } |
| 185 | + |
| 186 | + /** |
| 187 | + * Returns the value vector for a specific record batch and column. |
| 188 | + * |
| 189 | + * @param recordBatchIndex index of the record batch |
| 190 | + * @param columnIndex index of the column |
| 191 | + * @return ValueVector for the specified position |
| 192 | + */ |
| 193 | + protected ValueVector getColumnVector(int recordBatchIndex, int columnIndex) { |
| 194 | + return recordBatchList.get(recordBatchIndex).get(columnIndex); |
| 195 | + } |
| 196 | + |
| 197 | + /** |
| 198 | + * Returns the current status of the chunk. |
| 199 | + * |
| 200 | + * @return current ChunkStatus |
| 201 | + */ |
| 202 | + protected ChunkStatus getStatus() { |
| 203 | + return stateMachine.getCurrentStatus(); |
| 204 | + } |
| 205 | + |
| 206 | + /** |
| 207 | + * Updates the status of the chunk. |
| 208 | + * |
| 209 | + * @param targetStatus new status to set |
| 210 | + */ |
| 211 | + protected void setStatus(ChunkStatus targetStatus) { |
| 212 | + try { |
| 213 | + stateMachine.transition(targetStatus); |
| 214 | + } catch (DatabricksParsingException e) { |
| 215 | + LOGGER.warn( |
| 216 | + "Failed to transition to state [%s] from state [%s] for chunk [%d] and statement [%s]. Stack trace: %s", |
| 217 | + targetStatus, getStatus(), chunkIndex, statementId, ExceptionUtils.getStackTrace(e)); |
| 218 | + } |
| 219 | + } |
| 220 | + |
| 221 | + /** |
| 222 | + * Returns an iterator for traversing the rows in this chunk. |
| 223 | + * |
| 224 | + * @return ArrowResultChunkIterator for this chunk |
| 225 | + */ |
| 226 | + protected ArrowResultChunkIterator getChunkIterator() { |
| 227 | + return new ArrowResultChunkIterator(this); |
| 228 | + } |
| 229 | + |
| 230 | + /** |
| 231 | + * Sets the external link details for this chunk. |
| 232 | + * |
| 233 | + * @param chunk the external link information |
| 234 | + */ |
| 235 | + protected void setChunkLink(ExternalLink chunk) { |
| 236 | + chunkLink = chunk; |
| 237 | + expiryTime = Instant.parse(chunk.getExpiration()); |
| 238 | + setStatus(ChunkStatus.URL_FETCHED); |
| 239 | + } |
| 240 | + |
| 241 | + protected CompletableFuture<Void> getChunkReadyFuture() { |
| 242 | + return chunkReadyFuture; |
| 243 | + } |
| 244 | + |
| 245 | + /** |
| 246 | + * Waits for the chunk to be ready for consumption. |
| 247 | + * |
| 248 | + * @throws ExecutionException if the chunk download or processing throws an exception |
| 249 | + * @throws InterruptedException if the thread is interrupted while waiting |
| 250 | + * @throws TimeoutException if the chunk is not ready within the timeout |
| 251 | + */ |
| 252 | + protected void waitForChunkReady() |
| 253 | + throws ExecutionException, InterruptedException, TimeoutException { |
| 254 | + try { |
| 255 | + chunkReadyFuture.get(CHUNK_READY_TIMEOUT_SECONDS, TimeUnit.SECONDS); |
| 256 | + } catch (InterruptedException e) { |
| 257 | + Thread.currentThread().interrupt(); |
| 258 | + throw e; |
| 259 | + } |
| 260 | + } |
| 261 | + |
| 262 | + /** |
| 263 | + * Decompresses the given {@link InputStream} and initializes {@link #recordBatchList} from |
| 264 | + * decompressed stream. |
| 265 | + * |
| 266 | + * @param inputStream the input stream to decompress |
| 267 | + * @throws DatabricksSQLException if decompression fails |
| 268 | + * @throws IOException if reading from the stream fails |
| 269 | + */ |
| 270 | + protected void initializeData(InputStream inputStream) |
| 271 | + throws DatabricksSQLException, IOException { |
| 272 | + LOGGER.debug("Parsing data for chunk index %s and statement %s", chunkIndex, statementId); |
| 273 | + ArrowData arrowData = getRecordBatchList(inputStream, rootAllocator, statementId, chunkIndex); |
| 274 | + recordBatchList = arrowData.getValueVectors(); |
| 275 | + arrowMetadata = arrowData.getMetadata(); |
| 276 | + LOGGER.debug("Data parsed for chunk index %s and statement %s", chunkIndex, statementId); |
| 277 | + } |
| 278 | + |
| 279 | + protected List<String> getArrowMetadata() { |
| 280 | + return arrowMetadata; |
| 281 | + } |
| 282 | + |
| 283 | + /** |
| 284 | + * Reads Arrow format data from an input stream and converts it into a list of record batches. |
| 285 | + * Each record batch is represented as a list of {@link ValueVector}s. |
| 286 | + */ |
| 287 | + private ArrowData getRecordBatchList( |
| 288 | + InputStream inputStream, |
| 289 | + BufferAllocator rootAllocator, |
| 290 | + StatementId statementId, |
| 291 | + long chunkIndex) |
| 292 | + throws IOException { |
| 293 | + List<List<ValueVector>> recordBatchList = new ArrayList<>(); |
| 294 | + List<String> metadata = new ArrayList<>(); |
| 295 | + try (ArrowStreamReader arrowStreamReader = new ArrowStreamReader(inputStream, rootAllocator)) { |
| 296 | + VectorSchemaRoot vectorSchemaRoot = arrowStreamReader.getVectorSchemaRoot(); |
| 297 | + boolean fetchedMetadata = false; |
| 298 | + while (arrowStreamReader.loadNextBatch()) { |
| 299 | + if (!fetchedMetadata) { |
| 300 | + metadata = getMetadataInformationFromSchemaRoot(vectorSchemaRoot); |
| 301 | + fetchedMetadata = true; |
| 302 | + } |
| 303 | + recordBatchList.add(getVectorsFromSchemaRoot(vectorSchemaRoot, rootAllocator)); |
| 304 | + vectorSchemaRoot.clear(); |
| 305 | + } |
| 306 | + } catch (ClosedByInterruptException e) { |
| 307 | + // release resources if thread is interrupted when reading arrow data |
| 308 | + LOGGER.error( |
| 309 | + e, |
| 310 | + "Data parsing interrupted for chunk index [%s] and statement [%s]. Error [%s]", |
| 311 | + chunkIndex, |
| 312 | + statementId, |
| 313 | + e.getMessage()); |
| 314 | + purgeArrowData(recordBatchList); |
| 315 | + } catch (IOException e) { |
| 316 | + LOGGER.error( |
| 317 | + "Error while reading arrow data, purging the local list and rethrowing the exception."); |
| 318 | + purgeArrowData(recordBatchList); |
| 319 | + throw e; |
| 320 | + } |
| 321 | + |
| 322 | + return new ArrowData(recordBatchList, metadata); |
| 323 | + } |
| 324 | + |
| 325 | + private List<String> getMetadataInformationFromSchemaRoot(VectorSchemaRoot vectorSchemaRoot) { |
| 326 | + return vectorSchemaRoot.getFieldVectors().stream() |
| 327 | + .map(fieldVector -> fieldVector.getField().getMetadata().get(ARROW_METADATA_KEY)) |
| 328 | + .collect(Collectors.toList()); |
| 329 | + } |
| 330 | + |
| 331 | + /** |
| 332 | + * Transfers the data from the given {@link VectorSchemaRoot} to a list of {@link ValueVector}s. |
| 333 | + */ |
| 334 | + private List<ValueVector> getVectorsFromSchemaRoot( |
| 335 | + VectorSchemaRoot vectorSchemaRoot, BufferAllocator rootAllocator) { |
| 336 | + return vectorSchemaRoot.getFieldVectors().stream() |
| 337 | + .map( |
| 338 | + fieldVector -> { |
| 339 | + TransferPair transferPair = fieldVector.getTransferPair(rootAllocator); |
| 340 | + transferPair.transfer(); |
| 341 | + return transferPair.getTo(); |
| 342 | + }) |
| 343 | + .collect(Collectors.toList()); |
| 344 | + } |
| 345 | + |
| 346 | + private void logAllocatorStats(String event) { |
| 347 | + long allocatedMemory = rootAllocator.getAllocatedMemory(); |
| 348 | + long peakMemory = rootAllocator.getPeakMemoryAllocation(); |
| 349 | + long headRoom = rootAllocator.getHeadroom(); |
| 350 | + long initReservation = rootAllocator.getInitReservation(); |
| 351 | + |
| 352 | + LOGGER.debug( |
| 353 | + "Chunk allocator stats Log - Event: %s, Chunk Index: %s, Allocated Memory: %s, Peak Memory: %s, Headroom: %s, Init Reservation: %s", |
| 354 | + event, chunkIndex, allocatedMemory, peakMemory, headRoom, initReservation); |
| 355 | + } |
| 356 | + |
| 357 | + /** Releases all Arrow-related resources and clears the record batch list. */ |
| 358 | + private void purgeArrowData(List<List<ValueVector>> recordBatchList) { |
| 359 | + recordBatchList.forEach(vectors -> vectors.forEach(ValueVector::close)); |
| 360 | + recordBatchList.clear(); |
| 361 | + } |
| 362 | +} |
0 commit comments