Skip to content

Commit b9ca65b

Browse files
dmjioclaude
andcommitted
fix|test|doc: Correct by-key reduction output dtypes, expand tests and docs
Fix countByKey/allTrueByKey/anyTrueByKey return types to reflect the actual ArrayFire output dtype (Word32/CBool) rather than the input value type, preventing host over-reads on toList. Add property tests for by-key reductions, vector round-trips, and bitNot involution/complement. Document the FFI marshalling combinators, Eq/Num Array instances, and several API functions. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
1 parent 83dd090 commit b9ca65b

11 files changed

Lines changed: 217 additions & 14 deletions

File tree

src/ArrayFire/Algorithm.hs

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -757,6 +757,8 @@ maxByKey keys vals (fromIntegral -> dim) =
757757
op2p2kv keys vals (\ko vo k v -> af_max_by_key ko vo k v dim)
758758

759759
-- | True if all values are true within each key group.
760+
--
761+
-- The value output is always boolean (@b8@) regardless of the input value type.
760762
allTrueByKey
761763
:: AFType a
762764
=> Array Int
@@ -765,11 +767,13 @@ allTrueByKey
765767
-- ^ Values array (treated as boolean)
766768
-> Int
767769
-- ^ Dimension
768-
-> (Array Int, Array a)
770+
-> (Array Int, Array CBool)
769771
allTrueByKey keys vals (fromIntegral -> dim) =
770772
op2p2kv keys vals (\ko vo k v -> af_all_true_by_key ko vo k v dim)
771773

772774
-- | True if any value is true within each key group.
775+
--
776+
-- The value output is always boolean (@b8@) regardless of the input value type.
773777
anyTrueByKey
774778
:: AFType a
775779
=> Array Int
@@ -778,11 +782,13 @@ anyTrueByKey
778782
-- ^ Values array (treated as boolean)
779783
-> Int
780784
-- ^ Dimension
781-
-> (Array Int, Array a)
785+
-> (Array Int, Array CBool)
782786
anyTrueByKey keys vals (fromIntegral -> dim) =
783787
op2p2kv keys vals (\ko vo k v -> af_any_true_by_key ko vo k v dim)
784788

785789
-- | Count non-zero values within each key group.
790+
--
791+
-- The value output is always @u32@ regardless of the input value type.
786792
countByKey
787793
:: AFType a
788794
=> Array Int
@@ -791,6 +797,6 @@ countByKey
791797
-- ^ Values array
792798
-> Int
793799
-- ^ Dimension
794-
-> (Array Int, Array a)
800+
-> (Array Int, Array Word32)
795801
countByKey keys vals (fromIntegral -> dim) =
796802
op2p2kv keys vals (\ko vo k v -> af_count_by_key ko vo k v dim)

src/ArrayFire/Arith.hs

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1299,7 +1299,8 @@ atan2Batched x y (fromIntegral . fromEnum -> batch) = do
12991299
x `op2` y $ \arr arr1 arr2 ->
13001300
af_atan2 arr arr1 arr2 batch
13011301

1302-
-- | Take the cplx2 of all values in an 'Array'
1302+
-- | Construct a complex 'Array' from two real 'Array's, taking the first as the
1303+
-- real part and the second as the imaginary part.
13031304
--
13041305
-- >>> A.cplx2 (A.vector @Int 10 [1..]) (A.vector @Int 10 [1..])
13051306
-- ArrayFire Array
@@ -1321,12 +1322,13 @@ cplx2
13211322
-> Array a
13221323
-- ^ Second input (imaginary part)
13231324
-> Array (Complex a)
1324-
-- ^ Result of cplx2
1325+
-- ^ Complex result with the inputs as real and imaginary parts
13251326
cplx2 x y =
13261327
x `op2` y $ \arr arr1 arr2 ->
13271328
af_cplx2 arr arr1 arr2 1
13281329

1329-
-- | Take the cplx2Batched of all values in an 'Array'
1330+
-- | Construct a complex 'Array' from two real 'Array's (real and imaginary
1331+
-- parts), with explicit control over batched broadcasting of the inputs.
13301332
--
13311333
-- >>> A.cplx2Batched (A.vector @Int 10 [1..]) (A.vector @Int 10 [1..]) True
13321334
-- ArrayFire Array
@@ -1348,9 +1350,9 @@ cplx2Batched
13481350
-> Array a
13491351
-- ^ Second input (imaginary part)
13501352
-> Bool
1351-
-- ^ Use batch
1353+
-- ^ Whether to enable batched broadcasting of the inputs
13521354
-> Array (Complex a)
1353-
-- ^ Result of cplx2
1355+
-- ^ Complex result with the inputs as real and imaginary parts
13541356
cplx2Batched x y (fromIntegral . fromEnum -> batch) = do
13551357
x `op2` y $ \arr arr1 arr2 ->
13561358
af_cplx2 arr arr1 arr2 batch

src/ArrayFire/Array.hs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,9 @@ mkArray dims xs =
211211

212212
-- | Constructs an 'Array' from a 'Storable' 'Vector', avoiding the intermediate list allocation of 'mkArray'.
213213
--
214-
-- The vector's pinned buffer is passed directly to @af_create_array@.
214+
-- The vector's contiguous buffer is handed straight to @af_create_array@, which
215+
-- copies it into the 'Array' (and uploads to device memory on GPU backends), so
216+
-- no intermediate Haskell list is built.
215217
-- Throws 'AFException' if the vector length does not match the product of the given dimensions.
216218
--
217219
-- >>> fromVector @Double [3] (Data.Vector.Storable.fromList [1,2,3])

src/ArrayFire/Data.hs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -396,6 +396,10 @@ joinMany (fromIntegral -> n) (fmap (\(Array fp) -> fp) -> arrays) = unsafePerfor
396396
Array <$>
397397
newForeignPtr af_release_array_finalizer newPtr
398398

399+
-- | Marshals a list of 'ForeignPtr' into a temporary, contiguous C array of
400+
-- raw pointers, keeping every 'ForeignPtr' alive for the duration of the
401+
-- action. The continuation receives the number of pointers and a pointer to
402+
-- the array.
399403
withManyForeignPtr :: [ForeignPtr a] -> (Int -> Ptr (Ptr a) -> IO b) -> IO b
400404
withManyForeignPtr fptrs action = go [] fptrs
401405
where

0 commit comments

Comments
 (0)