Skip to content

Commit 90ef718

Browse files
dmjioclaude
andcommitted
fix|api: Zero-init FFI output slots; add calloca; Order type for sort
Add `calloca` (zero-initialised stack alloc via alloca+fillBytes) and use it in infoFromArray2/22/3 so the imaginary-part output pointer is always 0.0 for real-valued arrays instead of uninitialized stack garbage, matching the Rust bindings' explicit zero-init pattern. Replace Bool with a new Order (Asc | Desc) type in sort, sortIndex, and sortByKey for clarity. Fix sumNaN/productNaN/allTrue docstrings to use inputs that actually exercise the behaviour being documented. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent a3db69d commit 90ef718

2 files changed

Lines changed: 31 additions & 18 deletions

File tree

src/ArrayFire/Algorithm.hs

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ sum x (fromIntegral -> n) = (x `op1` (\p a -> af_sum p a n))
6666

6767
-- | Sum all of the elements in 'Array' along the specified dimension, using a default value for NaN
6868
--
69-
-- >>> A.sumNaN (A.vector @Double 10 [1..]) 0 0.0
69+
-- >>> let nan = 0/0 in A.sumNaN (A.vector @Double 10 (nan : [1..])) 0 10.0
7070
-- ArrayFire Array
7171
-- [1 1 1 1]
7272
-- 55.0000
@@ -100,7 +100,7 @@ product x (fromIntegral -> n) = (x `op1` (\p a -> af_product p a n))
100100

101101
-- | Product all of the elements in 'Array' along the specified dimension, using a default value for NaN
102102
--
103-
-- >>> A.productNaN (A.vector @Double 10 [1..]) 0 0.0
103+
-- >>> let nan = 0/0 in A.productNaN (A.vector @Double 10 (nan : [1..])) 0 2.0
104104
-- ArrayFire Array
105105
-- [1 1 1 1]
106106
-- 3628800.0000
@@ -150,10 +150,10 @@ max x (fromIntegral -> n) = x `op1` (\p a -> af_max p a n)
150150

151151
-- | Find if all elements in an 'Array' are 'True' along a dimension
152152
--
153-
-- >>> A.allTrue (A.vector @CBool 10 (repeat 0)) 0
153+
-- >>> A.allTrue (A.vector @CBool 10 (repeat 1)) 0
154154
-- ArrayFire Array
155155
-- [1 1 1 1]
156-
-- 0
156+
-- 1
157157
allTrue
158158
:: AFType a
159159
=> Array a
@@ -212,7 +212,7 @@ sumAll = (`infoFromArray2` af_sum_all)
212212

213213
-- | Sum all elements in an 'Array' along all dimensions, using a default value for NaN
214214
--
215-
-- >>> A.sumNaNAll (A.vector @Double 10 [1..]) 0.0
215+
-- >>> let nan = 0/0 in A.sumNaNAll (A.vector @Double 10 (nan : [1..])) 0.0
216216
-- (55.0,0.0)
217217
sumNaNAll
218218
:: (AFType a, Fractional a)
@@ -516,15 +516,15 @@ diff2 a (fromIntegral -> n) = a `op1` (\p x -> af_diff2 p x n)
516516

517517
-- | Sort an Array along a specified dimension, specifying ordering of results (ascending / descending)
518518
--
519-
-- >>> A.sort (A.vector @Double 4 [ 2,4,3,1 ]) 0 True
519+
-- >>> A.sort (A.vector @Double 4 [ 2,4,3,1 ]) 0 Asc
520520
-- ArrayFire Array
521521
-- [4 1 1 1]
522522
-- 1.0000
523523
-- 2.0000
524524
-- 3.0000
525525
-- 4.0000
526526
--
527-
-- >>> A.sort (A.vector @Double 4 [ 2,4,3,1 ]) 0 False
527+
-- >>> A.sort (A.vector @Double 4 [ 2,4,3,1 ]) 0 Desc
528528
-- ArrayFire Array
529529
-- [4 1 1 1]
530530
-- 4.0000
@@ -537,7 +537,7 @@ sort
537537
-- ^ Input array
538538
-> Int
539539
-- ^ Dimension along `sort` is performed
540-
-> Bool
540+
-> Order
541541
-- ^ Return results in ascending order
542542
-> Array a
543543
-- ^ Will contain sorted input
@@ -546,7 +546,7 @@ sort a (fromIntegral -> n) (fromIntegral . fromEnum -> b) =
546546

547547
-- | Sort an 'Array' along a specified dimension, specifying ordering of results (ascending / descending), returns indices of sorted results
548548
--
549-
-- >>> A.sortIndex (A.vector @Double 4 [3,2,1,4]) 0 True
549+
-- >>> A.sortIndex (A.vector @Double 4 [3,2,1,4]) 0 Asc
550550
-- (ArrayFire Array
551551
-- [4 1 1 1]
552552
-- 1.0000
@@ -566,13 +566,18 @@ sortIndex
566566
-- ^ Input array
567567
-> Int
568568
-- ^ Dimension along `sortIndex` is performed
569-
-> Bool
569+
-> Order
570570
-- ^ Return results in ascending order
571571
-> (Array a, Array Word32)
572572
-- ^ Contains the sorted, contains indices for original input
573573
sortIndex a (fromIntegral -> n) (fromIntegral . fromEnum -> b) =
574574
a `op2p` (\p1 p2 p3 -> af_sort_index p1 p2 p3 n b)
575575

576+
577+
-- | Data type for expressing sort order
578+
data Order = Desc | Asc
579+
deriving (Enum, Show, Eq)
580+
576581
-- | Sort an 'Array' along a specified dimension by keys, specifying ordering of results (ascending / descending)
577582
--
578583
-- >>> A.sortByKey (A.vector @Double 4 [2,1,4,3]) (A.vector @Double 4 [10,9,8,7]) 0 True
@@ -597,7 +602,7 @@ sortByKey
597602
-- ^ Values input array
598603
-> Int
599604
-- ^ Dimension along which to perform the operation
600-
-> Bool
605+
-> Order
601606
-- ^ Return results in ascending order
602607
-> (Array a, Array a)
603608
sortByKey a1 a2 (fromIntegral -> n) (fromIntegral . fromEnum -> b) =

src/ArrayFire/FFI.hs

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,16 @@ import Foreign.Storable
3434
import Foreign.Ptr
3535
import Foreign.C
3636
import Foreign.Marshal.Alloc
37+
import Foreign.Marshal.Utils (fillBytes)
3738
import System.IO.Unsafe
3839

40+
-- | Like 'alloca' but zero-initialises the memory before handing the pointer
41+
-- to the continuation. Prevents uninitialized stack garbage from leaking into
42+
-- output scalars when the C function does not write the imaginary-part pointer
43+
-- for real-valued arrays (e.g. af_mean_all_weighted).
44+
calloca :: forall a b. Storable a => (Ptr a -> IO b) -> IO b
45+
calloca f = alloca $ \p -> fillBytes p 0 (sizeOf (undefined :: a)) >> f p
46+
3947
foreign import ccall unsafe "af_cast"
4048
af_cast :: Ptr AFArray -> AFArray -> AFDtype -> IO AFErr
4149

@@ -538,8 +546,8 @@ infoFromArray2
538546
infoFromArray2 (Array fptr1) op =
539547
unsafePerformIO . mask_ $ do
540548
withForeignPtr fptr1 $ \ptr1 ->
541-
bracket (callocBytes (sizeOf (undefined :: a))) free $ \ptrInput1 ->
542-
bracket (callocBytes (sizeOf (undefined :: b))) free $ \ptrInput2 -> do
549+
calloca $ \ptrInput1 ->
550+
calloca $ \ptrInput2 -> do
543551
throwAFError =<< op ptrInput1 ptrInput2 ptr1
544552
(,) <$> peek ptrInput1 <*> peek ptrInput2
545553

@@ -556,8 +564,8 @@ infoFromArray22 (Array fptr1) (Array fptr2) op =
556564
unsafePerformIO . mask_ $ do
557565
withForeignPtr fptr1 $ \ptr1 ->
558566
withForeignPtr fptr2 $ \ptr2 ->
559-
bracket (callocBytes (sizeOf (undefined :: a))) free $ \ptrInput1 ->
560-
bracket (callocBytes (sizeOf (undefined :: b))) free $ \ptrInput2 -> do
567+
calloca $ \ptrInput1 ->
568+
calloca $ \ptrInput2 -> do
561569
throwAFError =<< op ptrInput1 ptrInput2 ptr1 ptr2
562570
(,) <$> peek ptrInput1 <*> peek ptrInput2
563571

@@ -572,9 +580,9 @@ infoFromArray3
572580
infoFromArray3 (Array fptr1) op =
573581
unsafePerformIO . mask_ $
574582
withForeignPtr fptr1 $ \ptr1 ->
575-
bracket (callocBytes (sizeOf (undefined :: a))) free $ \ptrInput1 ->
576-
bracket (callocBytes (sizeOf (undefined :: b))) free $ \ptrInput2 ->
577-
bracket (callocBytes (sizeOf (undefined :: c))) free $ \ptrInput3 -> do
583+
calloca $ \ptrInput1 ->
584+
calloca $ \ptrInput2 ->
585+
calloca $ \ptrInput3 -> do
578586
throwAFError =<< op ptrInput1 ptrInput2 ptrInput3 ptr1
579587
(,,) <$> peek ptrInput1
580588
<*> peek ptrInput2

0 commit comments

Comments
 (0)