|
9 | 9 |
|
10 | 10 | use std::sync::Arc; |
11 | 11 |
|
12 | | -use futures::executor::block_on; |
13 | 12 | use vortex::array::ArrayRef; |
14 | 13 | use vortex::array::DynArray; |
15 | 14 | use vortex::array::ExecutionCtx; |
@@ -324,7 +323,18 @@ impl PlanBuilderState<'_> { |
324 | 323 | fn walk_primitive(&mut self, array: ArrayRef) -> VortexResult<Pipeline> { |
325 | 324 | let prim = array.to_canonical()?.into_primitive(); |
326 | 325 | let PrimitiveArrayParts { buffer, .. } = prim.into_parts(); |
327 | | - let device_buf = block_on(self.ctx.ensure_on_device(buffer))?; |
| 326 | + |
| 327 | + // TODO(0ax1): Optimize device buffer allocation and copying. |
| 328 | + // |
| 329 | + // Ideally, there would be a buffer pool of preallocated device memory |
| 330 | + // such that retrieving a device pointer is O(1) when building the |
| 331 | + // dynamic dispatch plan. In the current setup, we need to allocate the |
| 332 | + // buffer before we can get the device pointer. Because the host |
| 333 | + // memory comes from the global allocator, which, unlike |
| 334 | + // `cudaHostAlloc`, does not pin it to physical addresses, the |
| 335 | + // subsequent host-to-device copy is synchronous and cannot be pushed |
| 336 | + // to the CUDA stream as an async operation. |
| 337 | + let device_buf = self.ctx.ensure_on_device_sync(buffer)?; |
328 | 338 | let ptr = device_buf.cuda_device_ptr()?; |
329 | 339 | self.device_buffers.push(device_buf); |
330 | 340 | Ok(Pipeline { |
@@ -354,7 +364,7 @@ impl PlanBuilderState<'_> { |
354 | 364 | vortex_bail!("Dynamic dispatch does not support BitPackedArray with patches"); |
355 | 365 | } |
356 | 366 |
|
357 | | - let device_buf = block_on(self.ctx.ensure_on_device(packed))?; |
| 367 | + let device_buf = self.ctx.ensure_on_device_sync(packed)?; |
358 | 368 | let ptr = device_buf.cuda_device_ptr()?; |
359 | 369 | self.device_buffers.push(device_buf); |
360 | 370 | Ok(Pipeline { |
@@ -490,14 +500,26 @@ impl PlanBuilderState<'_> { |
490 | 500 | } |
491 | 501 |
|
492 | 502 | /// Extract a FoR reference scalar as u64 bits. |
| 503 | +/// |
| 504 | +/// `TryFrom<&Scalar>` for primitive types requires an exact ptype match, |
| 505 | +/// so we must try each width individually rather than relying on widening. |
493 | 506 | fn extract_for_reference(for_arr: &FoRArray) -> VortexResult<u64> { |
494 | | - if let Ok(v) = u32::try_from(for_arr.reference_scalar()) { |
| 507 | + let s = for_arr.reference_scalar(); |
| 508 | + if let Ok(v) = u8::try_from(s) { |
| 509 | + Ok(v as u64) |
| 510 | + } else if let Ok(v) = i8::try_from(s) { |
| 511 | + Ok(v as u8 as u64) |
| 512 | + } else if let Ok(v) = u16::try_from(s) { |
| 513 | + Ok(v as u64) |
| 514 | + } else if let Ok(v) = i16::try_from(s) { |
| 515 | + Ok(v as u16 as u64) |
| 516 | + } else if let Ok(v) = u32::try_from(s) { |
495 | 517 | Ok(v as u64) |
496 | | - } else if let Ok(v) = i32::try_from(for_arr.reference_scalar()) { |
| 518 | + } else if let Ok(v) = i32::try_from(s) { |
497 | 519 | Ok(v as u32 as u64) |
498 | | - } else if let Ok(v) = u64::try_from(for_arr.reference_scalar()) { |
| 520 | + } else if let Ok(v) = u64::try_from(s) { |
499 | 521 | Ok(v) |
500 | | - } else if let Ok(v) = i64::try_from(for_arr.reference_scalar()) { |
| 522 | + } else if let Ok(v) = i64::try_from(s) { |
501 | 523 | Ok(v as u64) |
502 | 524 | } else { |
503 | 525 | vortex_bail!("Cannot extract FoR reference as an integer type") |
|
0 commit comments