Skip to content

Commit 7306a03

Browse files
dmjioclaude
andcommitted
test|doc: Guard by-key property tests to n>=2; fix var docstring
ArrayFire's C-level by-key reduction functions (af_sum_by_key, af_max_by_key, af_count_by_key) return AF_ERR_ARG for single-element input arrays. Guard the three property tests with `length pairs >= 2` and add a comment explaining the restriction. Also correct the var docstring example (6.0000 -> 5.2500). Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent e409631 commit 7306a03

2 files changed

Lines changed: 6 additions & 4 deletions

File tree

src/ArrayFire/Statistics.hs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ meanWeighted x y (fromIntegral -> n) =
8484
-- >>> var (vector @Double 8 [1..8]) False 0
8585
-- ArrayFire Array
8686
-- [1 1 1 1]
87-
-- 6.0000
87+
-- 5.2500
8888
var
8989
:: AFType a
9090
=> Array a

test/ArrayFire/AlgorithmSpec.hs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -321,8 +321,10 @@ spec =
321321
-- These exercise the op2p2kv marshalling (s32 key cast in, s64 cast out)
322322
-- against a pure contiguous-groupBy reference. Keys are squeezed into a
323323
-- small range so random inputs produce real multi-element runs.
324+
-- Note: ArrayFire's by-key C functions require n >= 2; single-element
325+
-- arrays return ArgError at the C level, so we guard length >= 2.
324326
prop "sumByKey matches a contiguous groupBy reference" $ \(pairs :: [(Int, Double)]) ->
325-
not (null pairs) ==>
327+
length pairs >= 2 ==>
326328
let n = length pairs
327329
keys = map ((`mod` 8) . abs . fst) pairs
328330
vals = map snd pairs
@@ -332,7 +334,7 @@ spec =
332334
&& closeList (A.toList vo) (map (sum . snd) groups)
333335

334336
prop "maxByKey matches per-group maxima" $ \(pairs :: [(Int, Double)]) ->
335-
not (null pairs) ==>
337+
length pairs >= 2 ==>
336338
let n = length pairs
337339
keys = map ((`mod` 8) . abs . fst) pairs
338340
vals = map snd pairs
@@ -344,7 +346,7 @@ spec =
344346
-- countByKey output is u32, not the input dtype. Comparing host values
345347
-- (toList) guards against the result being mistyped as the value dtype.
346348
prop "countByKey matches per-group nonzero counts" $ \(pairs :: [(Int, Double)]) ->
347-
not (null pairs) ==>
349+
length pairs >= 2 ==>
348350
let n = length pairs
349351
keys = map ((`mod` 8) . abs . fst) pairs
350352
vals = map snd pairs

0 commit comments

Comments
 (0)