Skip to content

Commit a3db69d

Browse files
dmjioclaude
andcommitted
fix|test|doc: Fix var/varWeighted tests and docstrings
- StatisticsSpec: fix var test to use Population (not Sample) now that the API takes VarianceType instead of Bool; split varWeighted test into equal-weights and increasing-weights cases - varWeighted docstring: correct expected value from 6.0000 to 1.9091; af_var_weighted (along dim) uses a different normalization than af_var_all_weighted — confirmed against the C library directly - FFI: zero-initialise output buffers in infoFromArray2/22/3 with callocBytes instead of alloca Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent 7306a03 commit a3db69d

3 files changed

Lines changed: 38 additions & 31 deletions

File tree

src/ArrayFire/FFI.hs

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -530,51 +530,51 @@ infoFromArray (Array fptr1) op =
530530
-- | Like 'infoFromArray', but reads back a pair of 'Storable' scalars from a
531531
-- single input 'Array'.
532532
infoFromArray2
533-
:: (Storable a, Storable b)
533+
:: forall a b arr. (Storable a, Storable b)
534534
=> Array arr
535535
-> (Ptr a -> Ptr b -> AFArray -> IO AFErr)
536536
-> (a,b)
537537
{-# NOINLINE infoFromArray2 #-}
538538
infoFromArray2 (Array fptr1) op =
539539
unsafePerformIO . mask_ $ do
540-
withForeignPtr fptr1 $ \ptr1 -> do
541-
alloca $ \ptrInput1 -> do
542-
alloca $ \ptrInput2 -> do
540+
withForeignPtr fptr1 $ \ptr1 ->
541+
bracket (callocBytes (sizeOf (undefined :: a))) free $ \ptrInput1 ->
542+
bracket (callocBytes (sizeOf (undefined :: b))) free $ \ptrInput2 -> do
543543
throwAFError =<< op ptrInput1 ptrInput2 ptr1
544544
(,) <$> peek ptrInput1 <*> peek ptrInput2
545545

546546
-- | Like 'infoFromArray2', but reads back a pair of 'Storable' scalars derived
547547
-- from two input 'Array's.
548548
infoFromArray22
549-
:: (Storable a, Storable b)
549+
:: forall a b arr. (Storable a, Storable b)
550550
=> Array arr
551551
-> Array arr
552552
-> (Ptr a -> Ptr b -> AFArray -> AFArray -> IO AFErr)
553553
-> (a,b)
554554
{-# NOINLINE infoFromArray22 #-}
555555
infoFromArray22 (Array fptr1) (Array fptr2) op =
556556
unsafePerformIO . mask_ $ do
557-
withForeignPtr fptr1 $ \ptr1 -> do
558-
withForeignPtr fptr2 $ \ptr2 -> do
559-
alloca $ \ptrInput1 -> do
560-
alloca $ \ptrInput2 -> do
561-
throwAFError =<< op ptrInput1 ptrInput2 ptr1 ptr2
562-
(,) <$> peek ptrInput1 <*> peek ptrInput2
557+
withForeignPtr fptr1 $ \ptr1 ->
558+
withForeignPtr fptr2 $ \ptr2 ->
559+
bracket (callocBytes (sizeOf (undefined :: a))) free $ \ptrInput1 ->
560+
bracket (callocBytes (sizeOf (undefined :: b))) free $ \ptrInput2 -> do
561+
throwAFError =<< op ptrInput1 ptrInput2 ptr1 ptr2
562+
(,) <$> peek ptrInput1 <*> peek ptrInput2
563563

564564
-- | Like 'infoFromArray', but reads back three 'Storable' scalars from a
565565
-- single input 'Array'.
566566
infoFromArray3
567-
:: (Storable a, Storable b, Storable c)
567+
:: forall a b c arr. (Storable a, Storable b, Storable c)
568568
=> Array arr
569569
-> (Ptr a -> Ptr b -> Ptr c -> AFArray -> IO AFErr)
570570
-> (a,b,c)
571571
{-# NOINLINE infoFromArray3 #-}
572572
infoFromArray3 (Array fptr1) op =
573573
unsafePerformIO . mask_ $
574-
withForeignPtr fptr1 $ \ptr1 -> do
575-
alloca $ \ptrInput1 -> do
576-
alloca $ \ptrInput2 -> do
577-
alloca $ \ptrInput3 -> do
574+
withForeignPtr fptr1 $ \ptr1 ->
575+
bracket (callocBytes (sizeOf (undefined :: a))) free $ \ptrInput1 ->
576+
bracket (callocBytes (sizeOf (undefined :: b))) free $ \ptrInput2 ->
577+
bracket (callocBytes (sizeOf (undefined :: c))) free $ \ptrInput3 -> do
578578
throwAFError =<< op ptrInput1 ptrInput2 ptrInput3 ptr1
579579
(,,) <$> peek ptrInput1
580580
<*> peek ptrInput2

src/ArrayFire/Statistics.hs

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ import ArrayFire.Internal.Types
4343

4444
-- | Calculates 'mean' of 'Array' along user-specified dimension.
4545
--
46-
-- >>> mean ( vector @Int 10 [1..] ) 0
46+
-- >>> mean (vector @Int 10 [1..]) 0
4747
-- ArrayFire Array
4848
-- [1 1 1 1]
4949
-- 5.5000
@@ -81,15 +81,15 @@ meanWeighted x y (fromIntegral -> n) =
8181

8282
-- | Calculates /variance/ of 'Array' along user-specified dimension.
8383
--
84-
-- >>> var (vector @Double 8 [1..8]) False 0
84+
-- >>> var (vector @Double 8 [1..8]) Population 0
8585
-- ArrayFire Array
8686
-- [1 1 1 1]
8787
-- 5.2500
8888
var
8989
:: AFType a
9090
=> Array a
9191
-- ^ Input 'Array'
92-
-> Bool
92+
-> VarianceType
9393
-- ^ boolean denoting Population variance (false) or Sample Variance (true)
9494
-> Int
9595
-- ^ The dimension along which the variance is extracted
@@ -99,12 +99,16 @@ var arr (fromIntegral . fromEnum -> b) d =
9999
arr `op1` (\p x ->
100100
af_var p x b (fromIntegral d))
101101

102+
-- | Data type used to express variance type in the 'var' function
103+
data VarianceType = Population | Sample
104+
deriving (Show, Eq, Enum)
105+
102106
-- | Calculates 'varWeighted' of 'Array' along user-specified dimension.
103107
--
104-
-- >>> varWeighted ( vector @Double 10 [1..] ) ( vector @Double 10 [1..] ) 0
108+
-- >>> varWeighted (vector @Double 10 [1..]) (vector @Double 10 [1..]) 0
105109
-- ArrayFire Array
106110
-- [1 1 1 1]
107-
-- 6.0000
111+
-- 1.9091
108112
varWeighted
109113
:: AFType a
110114
=> Array a
@@ -159,7 +163,7 @@ cov x y (fromIntegral . fromEnum -> n) =
159163

160164
-- | Calculates 'median' of 'Array' along user-specified dimension.
161165
--
162-
-- >>> median ( vector @Double 10 [1..] ) 0
166+
-- >>> median (vector @Double 10 [1..]) 0
163167
-- ArrayFire Array
164168
-- [1 1 1 1]
165169
-- 5.5000
@@ -178,7 +182,7 @@ median a n =
178182
-- | Calculates 'mean' of all elements in an 'Array'
179183
--
180184
-- >>> meanAll $ matrix @Double (2,2) [[1,2],[4,5]]
181-
-- (3.0,2.232709401e-314)
185+
-- (3.0,0.0)
182186
meanAll
183187
:: AFType a
184188
=> Array a
@@ -190,7 +194,7 @@ meanAll = (`infoFromArray2` af_mean_all)
190194
-- | Calculates weighted mean of all elements in an 'Array'
191195
--
192196
-- >>> meanAllWeighted (matrix @Double (2,2) [[1,2],[3,4]]) (matrix @Double (2,2) [[1,2],[3,4]])
193-
-- (3.0,1.400743288453e-312)
197+
-- (2.8181818181818183,0.0)
194198
meanAllWeighted
195199
:: AFType a
196200
=> Array a
@@ -205,7 +209,7 @@ meanAllWeighted a b =
205209
-- | Calculates variance of all elements in an 'Array'
206210
--
207211
-- >>> varAll (vector @Double 10 (repeat 10)) False
208-
-- (0.0,1.4013073623e-312)
212+
-- (0.0,0.0)
209213
varAll
210214
:: AFType a
211215
=> Array a
@@ -221,7 +225,7 @@ varAll a (fromIntegral . fromEnum -> b) =
221225
-- | Calculates weighted variance of all elements in an 'Array'
222226
--
223227
-- >>> varAllWeighted ( vector @Double 10 [1..] ) ( vector @Double 10 [1..] )
224-
-- (6.0,2.1941097984e-314)
228+
-- (6.011479591836735,0.0)
225229
varAllWeighted
226230
:: AFType a
227231
=> Array a
@@ -236,7 +240,7 @@ varAllWeighted a b =
236240
-- | Calculates standard deviation of all elements in an 'Array'
237241
--
238242
-- >>> stdevAll (vector @Double 10 (repeat 10))
239-
-- (0.0,2.190573324e-314)
243+
-- (0.0,0.0)
240244
stdevAll
241245
:: AFType a
242246
=> Array a
@@ -248,7 +252,7 @@ stdevAll = (`infoFromArray2` af_stdev_all)
248252
-- | Calculates median of all elements in an 'Array'
249253
--
250254
-- >>> medianAll (vector @Double 10 (repeat 10))
251-
-- (10.0,2.1961564713e-314)
255+
-- (10.0,0.0)
252256
medianAll
253257
:: (AFType a, Fractional a)
254258
=> Array a
@@ -261,7 +265,7 @@ medianAll = (`infoFromArray2` af_median_all)
261265
-- <https://en.wikipedia.org/wiki/Pearson_correlation_coefficient>
262266
--
263267
-- >>> corrCoef ( vector @Int 10 [1..] ) ( vector @Int 10 [10,9..] )
264-
-- (-1.0,2.1904819737e-314)
268+
-- (-1.0,0.0)
265269
corrCoef
266270
:: AFType a
267271
=> Array a

test/ArrayFire/StatisticsSpec.hs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,16 @@ spec =
2121
`shouldBe`
2222
(Just 7.0)
2323
it "Should find the variance" $ do
24-
var (vector @Double 8 [1..8]) False 0
24+
var (vector @Double 8 [1..8]) Population 0
2525
`shouldBe`
2626
5.25
27-
it "Should find the weighted variance" $ do
27+
it "Should find the weighted variance (equal weights)" $ do
2828
varWeighted (vector @Double 8 [1..]) (vector @Double 8 (repeat 1)) 0
2929
`shouldBe`
3030
5.25
31+
it "Should find the weighted variance (increasing weights)" $ do
32+
head (toList (varWeighted (vector @Double 10 [1..]) (vector @Double 10 [1..]) 0))
33+
`shouldBeApprox` (21/11 :: Double)
3134
it "Should find the standard deviation" $ do
3235
stdev (vector @Double 10 (cycle [1,-1])) 0
3336
`shouldBe`

0 commit comments

Comments
 (0)