Skip to content

Commit 2283116

Browse files
committed
prefer as_ cast
Signed-off-by: Alexander Droste <alexander.droste@protonmail.com>
1 parent ecbec13 commit 2283116

1 file changed

Lines changed: 39 additions & 76 deletions

File tree

vortex-cuda/src/dynamic_dispatch/plan_builder.rs

Lines changed: 39 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -14,19 +14,15 @@ use vortex::array::ExecutionCtx;
1414
use vortex::array::arrays::Dict;
1515
use vortex::array::arrays::Primitive;
1616
use vortex::array::arrays::Slice;
17-
use vortex::array::arrays::primitive::PrimitiveArrayParts;
1817
use vortex::array::buffer::BufferHandle;
1918
use vortex::array::session::ArraySession;
2019
use vortex::dtype::PType;
2120
use vortex::encodings::alp::ALP;
2221
use vortex::encodings::alp::ALPFloat;
2322
use vortex::encodings::fastlanes::BitPacked;
24-
use vortex::encodings::fastlanes::BitPackedArrayParts;
2523
use vortex::encodings::fastlanes::FoR;
2624
use vortex::encodings::runend::RunEnd;
27-
use vortex::encodings::runend::RunEndArrayParts;
2825
use vortex::encodings::sequence::Sequence;
29-
use vortex::encodings::sequence::SequenceArrayParts;
3026
use vortex::encodings::zigzag::ZigZag;
3127
use vortex::error::VortexResult;
3228
use vortex::error::vortex_bail;
@@ -55,44 +51,35 @@ pub struct MaterializedPlan {
5551
fn is_dyn_dispatch_compatible(array: &ArrayRef) -> bool {
5652
let id = array.encoding_id();
5753
if id == ALP::ID {
58-
if let Ok(a) = array.clone().try_into::<ALP>() {
59-
return a.patches().is_none() && a.dtype().as_ptype() == PType::F32;
60-
}
61-
return false;
54+
let arr = array.as_::<ALP>();
55+
return arr.patches().is_none() && arr.dtype().as_ptype() == PType::F32;
6256
}
6357
if id == BitPacked::ID {
64-
if let Ok(a) = array.clone().try_into::<BitPacked>() {
65-
return a.patches().is_none();
66-
}
67-
return false;
58+
return array.as_::<BitPacked>().patches().is_none();
6859
}
6960
if id == Dict::ID {
70-
if let Ok(a) = array.clone().try_into::<Dict>() {
71-
// As of now the dict dyn dispatch kernel requires
72-
// codes and values to have the same byte width.
73-
return match (
74-
PType::try_from(a.values().dtype()),
75-
PType::try_from(a.codes().dtype()),
76-
) {
77-
(Ok(values), Ok(codes)) => values.byte_width() == codes.byte_width(),
78-
_ => false,
79-
};
80-
}
81-
return false;
61+
let arr = array.as_::<Dict>();
62+
// As of now the dict dyn dispatch kernel requires
63+
// codes and values to have the same byte width.
64+
return match (
65+
PType::try_from(arr.values().dtype()),
66+
PType::try_from(arr.codes().dtype()),
67+
) {
68+
(Ok(values), Ok(codes)) => values.byte_width() == codes.byte_width(),
69+
_ => false,
70+
};
8271
}
8372
if id == RunEnd::ID {
84-
if let Ok(a) = array.clone().try_into::<RunEnd>() {
85-
// As of now the run-end dyn dispatch kernel requires
86-
// ends and values to have the same byte width.
87-
return match (
88-
PType::try_from(a.ends().dtype()),
89-
PType::try_from(a.values().dtype()),
90-
) {
91-
(Ok(e), Ok(v)) => e.byte_width() == v.byte_width(),
92-
_ => false,
93-
};
94-
}
95-
return false;
73+
let arr = array.as_::<RunEnd>();
74+
// As of now the run-end dyn dispatch kernel requires
75+
// ends and values to have the same byte width.
76+
return match (
77+
PType::try_from(arr.ends().dtype()),
78+
PType::try_from(arr.values().dtype()),
79+
) {
80+
(Ok(e), Ok(v)) => e.byte_width() == v.byte_width(),
81+
_ => false,
82+
};
9683
}
9784
id == FoR::ID
9885
|| id == ZigZag::ID
@@ -429,33 +416,23 @@ impl FusedPlan {
429416
}
430417

431418
fn walk_primitive(&mut self, array: ArrayRef) -> VortexResult<Stage> {
432-
let prim = array.to_canonical()?.into_primitive();
433-
let PrimitiveArrayParts { buffer, .. } = prim.into_parts();
419+
let prim = array.as_::<Primitive>();
434420
let buf_index = self.source_buffers.len();
435-
self.source_buffers.push(Some(buffer));
421+
self.source_buffers.push(Some(prim.buffer_handle().clone()));
436422
Ok(Stage::new(SourceOp::load(), Some(buf_index)))
437423
}
438424

439425
fn walk_bitpacked(&mut self, array: ArrayRef) -> VortexResult<Stage> {
440-
let bp = array
441-
.try_into::<BitPacked>()
442-
.map_err(|_| vortex_err!("Expected BitPackedArray"))?;
443-
let BitPackedArrayParts {
444-
offset,
445-
bit_width,
446-
packed,
447-
patches,
448-
..
449-
} = bp.into_parts();
450-
451-
if patches.is_some() {
426+
let bp = array.as_::<BitPacked>();
427+
428+
if bp.patches().is_some() {
452429
vortex_bail!("Dynamic dispatch does not support BitPackedArray with patches");
453430
}
454431

455432
let buf_index = self.source_buffers.len();
456-
self.source_buffers.push(Some(packed));
433+
self.source_buffers.push(Some(bp.packed().clone()));
457434
Ok(Stage::new(
458-
SourceOp::bitunpack(bit_width, offset),
435+
SourceOp::bitunpack(bp.bit_width(), bp.offset()),
459436
Some(buf_index),
460437
))
461438
}
@@ -465,9 +442,7 @@ impl FusedPlan {
465442
array: ArrayRef,
466443
pending_subtrees: &mut Vec<ArrayRef>,
467444
) -> VortexResult<Stage> {
468-
let for_arr = array
469-
.try_into::<FoR>()
470-
.map_err(|_| vortex_err!("Expected FoRArray"))?;
445+
let for_arr = array.as_::<FoR>();
471446
let ref_pvalue = for_arr
472447
.reference_scalar()
473448
.as_primitive()
@@ -488,9 +463,7 @@ impl FusedPlan {
488463
array: ArrayRef,
489464
pending_subtrees: &mut Vec<ArrayRef>,
490465
) -> VortexResult<Stage> {
491-
let zz = array
492-
.try_into::<ZigZag>()
493-
.map_err(|_| vortex_err!("Expected ZigZagArray"))?;
466+
let zz = array.as_::<ZigZag>();
494467
let encoded = zz.encoded().clone();
495468

496469
let mut pipeline = self.walk(encoded, pending_subtrees)?;
@@ -503,9 +476,7 @@ impl FusedPlan {
503476
array: ArrayRef,
504477
pending_subtrees: &mut Vec<ArrayRef>,
505478
) -> VortexResult<Stage> {
506-
let alp = array
507-
.try_into::<ALP>()
508-
.map_err(|_| vortex_err!("Expected ALPArray"))?;
479+
let alp = array.as_::<ALP>();
509480

510481
if alp.patches().is_some() {
511482
vortex_bail!("Dynamic dispatch does not support ALPArray with patches");
@@ -534,9 +505,7 @@ impl FusedPlan {
534505
array: ArrayRef,
535506
pending_subtrees: &mut Vec<ArrayRef>,
536507
) -> VortexResult<Stage> {
537-
let dict = array
538-
.try_into::<Dict>()
539-
.map_err(|_| vortex_err!("Expected DictArray"))?;
508+
let dict = array.as_::<Dict>();
540509
let values = dict.values().clone();
541510
let codes = dict.codes().clone();
542511

@@ -550,15 +519,10 @@ impl FusedPlan {
550519
}
551520

552521
fn walk_sequence(&mut self, array: ArrayRef) -> VortexResult<Stage> {
553-
let seq = array
554-
.try_into::<Sequence>()
555-
.map_err(|_| vortex_err!("Expected SequenceArray"))?;
556-
let SequenceArrayParts {
557-
base, multiplier, ..
558-
} = seq.into_parts();
522+
let seq = array.as_::<Sequence>();
559523

560524
Ok(Stage::new(
561-
SourceOp::sequence(base.cast()?, multiplier.cast()?),
525+
SourceOp::sequence(seq.base().cast()?, seq.multiplier().cast()?),
562526
None,
563527
))
564528
}
@@ -568,11 +532,10 @@ impl FusedPlan {
568532
array: ArrayRef,
569533
pending_subtrees: &mut Vec<ArrayRef>,
570534
) -> VortexResult<Stage> {
571-
let re = array
572-
.try_into::<RunEnd>()
573-
.map_err(|_| vortex_err!("Expected RunEndArray"))?;
535+
let re = array.as_::<RunEnd>();
574536
let offset = re.offset() as u64;
575-
let RunEndArrayParts { ends, values } = re.into_parts();
537+
let ends = re.ends().clone();
538+
let values = re.values().clone();
576539
let num_runs = ends.len() as u32;
577540
let num_values = values.len() as u32;
578541

0 commit comments

Comments
 (0)