Skip to content

Commit 794f670

Browse files
authored
fix: permute_axes bug (#1596)
As pointed out in #1589, permute_axes had a bug. This uses a simple algorithm to fix it; explanations for this algorithm are all over the internet, it's nothing fancy. I did take the opportunity to introduce proptest, which I'd like to make more use of over time. Down the line, it would be good to change the argument to `permute_axes` to be something like `T: Permutation`. Right now, `permute_axes` just has an assertion that the input is a real permutation. I'd call that a classic example of this library opting for asserts when we could be using better typing. But I think that will be easier / better to introduce after the core rework. Closes #1859
1 parent fbcf3bf commit 794f670

5 files changed

Lines changed: 155 additions & 24 deletions

File tree

.github/workflows/ci.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ jobs:
4444
toolchain: ${{ matrix.rust }}
4545
components: clippy
4646
- uses: Swatinem/rust-cache@v2
47-
- run: cargo clippy --features approx,serde,rayon
47+
- run: cargo clippy -F "${FEATURES}"
4848

4949
format:
5050
runs-on: ubuntu-latest

Cargo.lock

Lines changed: 108 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ quickcheck = { workspace = true }
5757
approx = { workspace = true, default-features = true }
5858
itertools = { workspace = true }
5959
ndarray-gen = { workspace = true }
60+
proptest = { workspace = true }
6061

6162
[features]
6263
default = ["std"]
@@ -103,6 +104,7 @@ num-traits = { version = "0.2", default-features = false }
103104
num-complex = { version = "0.4", default-features = false }
104105
approx = { version = "0.5", default-features = false }
105106
quickcheck = { version = "1.0", default-features = false }
107+
proptest = { version = "1.3.1" }
106108
rand = { version = "0.9.0", features = ["small_rng"] }
107109
rand_distr = { version = "0.5.0" }
108110
itertools = { version = "0.13.0", default-features = false, features = ["use_std"] }
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
# Seeds for failure cases proptest has generated in the past. It is
2+
# automatically read and these particular cases re-run before any
3+
# novel cases are generated.
4+
#
5+
# It is recommended to check this file in to source control so that
6+
# everyone who runs the test benefits from these saved cases.
7+
cc 6df087160c6028416bca035db64caf430391dda0af3b195969776eb715184310 # shrinks to p = [0, 1, 2, 3, 4, 5]

src/impl_methods.rs

Lines changed: 37 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,14 @@ use alloc::slice;
1010
use alloc::vec;
1111
#[cfg(not(feature = "std"))]
1212
use alloc::vec::Vec;
13+
#[cfg(test)]
14+
use proptest::prop_assert_eq;
15+
#[cfg(test)]
16+
use proptest::proptest;
17+
#[cfg(test)]
18+
use proptest::strategy::Just;
19+
#[cfg(test)]
20+
use proptest::strategy::Strategy;
1321
#[allow(unused_imports)]
1422
use rawpointer::PointerExt;
1523
use std::mem::{size_of, ManuallyDrop};
@@ -2585,30 +2593,16 @@ where
25852593

25862594
let dim = self.parts.dim.slice_mut();
25872595
let strides = self.parts.strides.slice_mut();
2588-
let axes = axes.slice();
2589-
2590-
// The cycle detection is done using a bitmask to track visited positions.
2591-
// For example, axes from [0,1,2] to [2, 0, 1]
2592-
// For axis values [1, 0, 2]:
2593-
// 1 << 1 // 0b0001 << 1 = 0b0010 (decimal 2)
2594-
// 1 << 0 // 0b0001 << 0 = 0b0001 (decimal 1)
2595-
// 1 << 2 // 0b0001 << 2 = 0b0100 (decimal 4)
2596-
//
2597-
// Each axis gets its own unique bit position in the bitmask:
2598-
// - Axis 0: bit 0 (rightmost)
2599-
// - Axis 1: bit 1
2600-
// - Axis 2: bit 2
2601-
//
2602-
let mut visited = 0usize;
2603-
for (new_axis, &axis) in axes.iter().enumerate() {
2604-
if (visited & (1 << axis)) != 0 {
2605-
continue;
2606-
}
26072596

2608-
dim.swap(axis, new_axis);
2609-
strides.swap(axis, new_axis);
2610-
2611-
visited |= (1 << axis) | (1 << new_axis);
2597+
for i in 0..axes.ndim() {
2598+
let mut index = axes[i];
2599+
while index < i {
2600+
index = axes[index];
2601+
}
2602+
if index != i {
2603+
dim.swap(i, index);
2604+
strides.swap(i, index);
2605+
}
26122606
}
26132607
}
26142608

@@ -3614,4 +3608,24 @@ mod tests
36143608
let result_slice = empty_slice.partition(0, Axis(0));
36153609
assert_eq!(result_slice.shape(), &[0, 3]);
36163610
}
3611+
3612+
/// Regression test for permute_axes
3613+
#[test]
3614+
fn test_permute_axes_regression()
3615+
{
3616+
let mut a = Array4::<u8>::zeros((1, 2, 3, 4));
3617+
a.permute_axes([3, 0, 1, 2]);
3618+
assert_eq!(a.shape(), &[4, 1, 2, 3]);
3619+
}
3620+
}
3621+
3622+
#[cfg(test)]
3623+
#[cfg_attr(miri, ignore)]
3624+
proptest! {
3625+
#[test]
3626+
fn test_permute_axes_6d(p in Just([0, 1, 2, 3, 4, 5]).prop_shuffle()) {
3627+
let mut arr: Array6<usize> = Array6::zeros((0, 1, 2, 3, 4, 5));
3628+
arr.permute_axes(p.clone());
3629+
prop_assert_eq!(arr.shape(), p);
3630+
}
36173631
}

0 commit comments

Comments
 (0)