Skip to content

Commit 10c3fc6

Browse files
authored
Assert_arrays_eq also executes the array and compares the results against scalar_at (#7222)
Execute arrays passed to assert_arrays_eq to check any execute_parent rules that might exist Signed-off-by: Robert Kruszewski <github@robertk.io>
1 parent 9f10dfa commit 10c3fc6

2 files changed

Lines changed: 86 additions & 32 deletions

File tree

vortex-array/src/arrays/assertions.rs

Lines changed: 85 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,38 @@
44
use std::fmt::Display;
55

66
use itertools::Itertools;
7+
use vortex_error::VortexExpect;
78

8-
pub fn format_indices<I: IntoIterator<Item = usize>>(indices: I) -> impl Display {
9+
use crate::ArrayRef;
10+
use crate::DynArray;
11+
use crate::ExecutionCtx;
12+
use crate::IntoArray;
13+
use crate::LEGACY_SESSION;
14+
use crate::RecursiveCanonical;
15+
use crate::VortexSessionExecute;
16+
17+
fn format_indices<I: IntoIterator<Item = usize>>(indices: I) -> impl Display {
918
indices.into_iter().format(",")
1019
}
1120

21+
/// Executes an array to recursive canonical form with the given execution context.
22+
fn execute_to_canonical(array: ArrayRef, ctx: &mut ExecutionCtx) -> ArrayRef {
23+
array
24+
.execute::<RecursiveCanonical>(ctx)
25+
.vortex_expect("failed to execute array to recursive canonical form")
26+
.0
27+
.into_array()
28+
}
29+
30+
/// Finds indices where two arrays differ based on `scalar_at` comparison.
31+
#[expect(clippy::unwrap_used)]
32+
fn find_mismatched_indices(left: &ArrayRef, right: &ArrayRef) -> Vec<usize> {
33+
assert_eq!(left.len(), right.len());
34+
(0..left.len())
35+
.filter(|i| left.scalar_at(*i).unwrap() != right.scalar_at(*i).unwrap())
36+
.collect()
37+
}
38+
1239
/// Asserts that the scalar at position `$n` in array `$arr` equals `$expected`.
1340
///
1441
/// This is a convenience macro for testing that avoids verbose scalar comparison code.
@@ -51,37 +78,64 @@ macro_rules! assert_arrays_eq {
5178
($left:expr, $right:expr) => {{
5279
let left = $left.clone();
5380
let right = $right.clone();
54-
if left.dtype() != right.dtype() {
55-
panic!(
56-
"assertion left == right failed: arrays differ in type: {} != {}.\n left: {}\n right: {}",
57-
left.dtype(),
58-
right.dtype(),
59-
left.display_values(),
60-
right.display_values()
61-
)
62-
}
81+
assert_eq!(
82+
left.dtype(),
83+
right.dtype(),
84+
"assertion left == right failed: arrays differ in type: {} != {}.\n left: {}\n right: {}",
85+
left.dtype(),
86+
right.dtype(),
87+
left.display_values(),
88+
right.display_values()
89+
);
6390

64-
if left.len() != right.len() {
65-
panic!(
66-
"assertion left == right failed: arrays differ in length: {} != {}.\n left: {}\n right: {}",
67-
left.len(),
68-
right.len(),
69-
left.display_values(),
70-
right.display_values()
71-
)
72-
}
91+
assert_eq!(
92+
left.len(),
93+
right.len(),
94+
"assertion left == right failed: arrays differ in length: {} != {}.\n left: {}\n right: {}",
95+
left.len(),
96+
right.len(),
97+
left.display_values(),
98+
right.display_values()
99+
);
73100

74-
let n = left.len();
75-
let mismatched_indices = (0..n)
76-
.filter(|i| left.scalar_at(*i).unwrap() != right.scalar_at(*i).unwrap())
77-
.collect::<Vec<_>>();
78-
if mismatched_indices.len() != 0 {
79-
panic!(
80-
"assertion left == right failed: arrays do not match at indices: {}.\n left: {}\n right: {}",
81-
$crate::arrays::format_indices(mismatched_indices),
82-
left.display_values(),
83-
right.display_values()
84-
)
85-
}
101+
#[allow(deprecated)]
102+
let left = left.to_array();
103+
#[allow(deprecated)]
104+
let right = right.to_array();
105+
$crate::arrays::assert_arrays_eq_impl(&left, &right);
86106
}};
87107
}
108+
109+
/// Implementation of `assert_arrays_eq!` — called by the macro after converting inputs to
110+
/// `ArrayRef`.
111+
#[track_caller]
112+
#[allow(clippy::panic)]
113+
pub fn assert_arrays_eq_impl(left: &ArrayRef, right: &ArrayRef) {
114+
let executed = execute_to_canonical(left.clone(), &mut LEGACY_SESSION.create_execution_ctx());
115+
116+
let left_right = find_mismatched_indices(left, right);
117+
let executed_right = find_mismatched_indices(&executed, right);
118+
119+
if !left_right.is_empty() || !executed_right.is_empty() {
120+
let mut msg = String::new();
121+
if !left_right.is_empty() {
122+
msg.push_str(&format!(
123+
"\n left != right at indices: {}",
124+
format_indices(left_right)
125+
));
126+
}
127+
if !executed_right.is_empty() {
128+
msg.push_str(&format!(
129+
"\n executed != right at indices: {}",
130+
format_indices(executed_right)
131+
));
132+
}
133+
panic!(
134+
"assertion failed: arrays do not match:{}\n left: {}\n right: {}\n executed: {}",
135+
msg,
136+
left.display_values(),
137+
right.display_values(),
138+
executed.display_values()
139+
)
140+
}
141+
}

vortex-array/src/arrays/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
mod assertions;
88

99
#[cfg(any(test, feature = "_test-harness"))]
10-
pub use assertions::format_indices;
10+
pub use assertions::assert_arrays_eq_impl;
1111

1212
#[cfg(test)]
1313
mod validation_tests;

0 commit comments

Comments
 (0)