Skip to content

Commit 16d9472

Browse files
dmjioclaude
andcommitted
Fix return types: CBool for boolean ops, Complex for cplx/real/imag
- isZero, isInf, isNaN: Array a -> Array CBool (af_is* always emits u8) - allTrue, anyTrue: Array a -> Int -> Array CBool (af_all/any_true emits u8) - where': Array a -> Array Word32 (af_where emits u32 indices) - cplx, cplx2, cplx2Batched: return Array (Complex a), not Array a - real, imag: simplified to (RealFloat a, AFType a, AFType (Complex a)) => Array (Complex a) -> Array a; previous signature was unlinked (a, b) - Update tests to match corrected return types Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent d99615f commit 16d9472

4 files changed

Lines changed: 46 additions & 45 deletions

File tree

src/ArrayFire/Algorithm.hs

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
module ArrayFire.Algorithm where
2828

2929
import Data.Word (Word32)
30+
import Foreign.C.Types (CBool)
3031

3132
import ArrayFire.FFI
3233
import ArrayFire.Internal.Algorithm
@@ -154,13 +155,13 @@ max x (fromIntegral -> n) = x `op1` (\p a -> af_max p a n)
154155
-- [1 1 1 1]
155156
-- 0
156157
allTrue
157-
:: forall a. AFType a
158+
:: AFType a
158159
=> Array a
159160
-- ^ Array input
160161
-> Int
161162
-- ^ Dimension along which to see if all elements are True
162-
-> Array a
163-
-- ^ Will contain the maximum of all values in the input array along dim
163+
-> Array CBool
164+
-- ^ Will contain 1 where all elements along dim are true, 0 otherwise
164165
allTrue x (fromIntegral -> n) =
165166
x `op1` (\p a -> af_all_true p a n)
166167

@@ -171,13 +172,13 @@ allTrue x (fromIntegral -> n) =
171172
-- [1 1 1 1]
172173
-- 0
173174
anyTrue
174-
:: forall a . AFType a
175+
:: AFType a
175176
=> Array a
176177
-- ^ Array input
177178
-> Int
178-
-- ^ Dimension along which to see if all elements are True
179-
-> Array a
180-
-- ^ Returns if all elements are true
179+
-- ^ Dimension along which to see if any elements are True
180+
-> Array CBool
181+
-- ^ Will contain 1 where any element along dim is true, 0 otherwise
181182
anyTrue x (fromIntegral -> n) =
182183
(x `op1` (\p a -> af_any_true p a n))
183184

@@ -473,8 +474,8 @@ where'
473474
:: AFType a
474475
=> Array a
475476
-- ^ Is the input array.
476-
-> Array a
477-
-- ^ will contain indices where input array is non-zero
477+
-> Array Word32
478+
-- ^ Indices where input array is non-zero
478479
where' = (`op1` af_where)
479480

480481
-- | First order numerical difference along specified dimension.

src/ArrayFire/Arith.hs

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
--------------------------------------------------------------------------------
2929
module ArrayFire.Arith where
3030

31-
import Prelude (Bool(..), ($), (.), flip, fromEnum, fromIntegral, Real, RealFrac)
31+
import Prelude (Bool(..), ($), (.), flip, fromEnum, fromIntegral, Real, RealFloat)
3232

3333
import Data.Coerce
3434
import Data.Proxy
@@ -1315,12 +1315,12 @@ atan2Batched x y (fromIntegral . fromEnum -> batch) = do
13151315
-- (9.0000,9.0000)
13161316
-- (10.0000,10.0000)
13171317
cplx2
1318-
:: AFType a
1318+
:: (RealFloat a, AFType a, AFType (Complex a))
13191319
=> Array a
1320-
-- ^ First input
1321-
-> Array a
1322-
-- ^ Second input
1320+
-- ^ First input (real part)
13231321
-> Array a
1322+
-- ^ Second input (imaginary part)
1323+
-> Array (Complex a)
13241324
-- ^ Result of cplx2
13251325
cplx2 x y =
13261326
x `op2` y $ \arr arr1 arr2 ->
@@ -1342,14 +1342,14 @@ cplx2 x y =
13421342
-- (9.0000,9.0000)
13431343
-- (10.0000,10.0000)
13441344
cplx2Batched
1345-
:: AFType a
1345+
:: (RealFloat a, AFType a, AFType (Complex a))
13461346
=> Array a
1347-
-- ^ First input
1347+
-- ^ First input (real part)
13481348
-> Array a
1349-
-- ^ Second input
1349+
-- ^ Second input (imaginary part)
13501350
-> Bool
13511351
-- ^ Use batch
1352-
-> Array a
1352+
-> Array (Complex a)
13531353
-- ^ Result of cplx2
13541354
cplx2Batched x y (fromIntegral . fromEnum -> batch) = do
13551355
x `op2` y $ \arr arr1 arr2 ->
@@ -1371,11 +1371,11 @@ cplx2Batched x y (fromIntegral . fromEnum -> batch) = do
13711371
-- (9.0000,0.0000)
13721372
-- (10.0000,0.0000)
13731373
cplx
1374-
:: AFType a
1374+
:: (RealFloat a, AFType a, AFType (Complex a))
13751375
=> Array a
13761376
-- ^ Input array
1377-
-> Array a
1378-
-- ^ Result of calling 'atan'
1377+
-> Array (Complex a)
1378+
-- ^ Complex array with input as real part and zero imaginary part
13791379
cplx = flip op1 af_cplx
13801380

13811381
-- | Execute real
@@ -1385,11 +1385,11 @@ cplx = flip op1 af_cplx
13851385
-- [1 1 1 1]
13861386
-- 10.0000
13871387
real
1388-
:: (AFType a, AFType (Complex b), RealFrac a, RealFrac b)
1389-
=> Array (Complex b)
1388+
:: (RealFloat a, AFType a, AFType (Complex a))
1389+
=> Array (Complex a)
13901390
-- ^ Input array
13911391
-> Array a
1392-
-- ^ Result of calling 'real'
1392+
-- ^ Real part of each element
13931393
real = flip op1 af_real
13941394

13951395
-- | Execute imag
@@ -1399,11 +1399,11 @@ real = flip op1 af_real
13991399
-- [1 1 1 1]
14001400
-- 11.0000
14011401
imag
1402-
:: (AFType a, AFType (Complex b), RealFrac a, RealFrac b)
1403-
=> Array (Complex b)
1402+
:: (RealFloat a, AFType a, AFType (Complex a))
1403+
=> Array (Complex a)
14041404
-- ^ Input array
14051405
-> Array a
1406-
-- ^ Result of calling 'imag'
1406+
-- ^ Imaginary part of each element
14071407
imag = flip op1 af_imag
14081408

14091409
-- | Execute conjg
@@ -2043,7 +2043,7 @@ isZero
20432043
:: AFType a
20442044
=> Array a
20452045
-- ^ Input array
2046-
-> Array a
2046+
-> Array CBool
20472047
-- ^ Result of calling 'isZero'
20482048
isZero = (`op1` af_iszero)
20492049

@@ -2066,7 +2066,7 @@ isInf
20662066
:: (Real a, AFType a)
20672067
=> Array a
20682068
-- ^ Input array
2069-
-> Array a
2069+
-> Array CBool
20702070
-- ^ will contain 1's where input is Inf or -Inf, and 0 otherwise.
20712071
isInf = (`op1` af_isinf)
20722072

@@ -2086,9 +2086,9 @@ isInf = (`op1` af_isinf)
20862086
-- 1
20872087
-- 1
20882088
isNaN
2089-
:: forall a. (AFType a, Real a)
2089+
:: (AFType a, Real a)
20902090
=> Array a
20912091
-- ^ Input array
2092-
-> Array a
2092+
-> Array CBool
20932093
-- ^ Will contain 1's where input is NaN, and 0 otherwise.
20942094
isNaN = (`op1` af_isnan)

test/ArrayFire/AlgorithmSpec.hs

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -94,13 +94,13 @@ spec =
9494
A.max (A.vector @(A.Complex Float) 3 [3 A.:+ 4, 1 A.:+ 0, 2 A.:+ 2]) 0 `shouldBe` A.scalar (3 A.:+ 4)
9595
A.max (A.vector @A.CBool 5 [0,1,1,0,1]) 0 `shouldBe` 1
9696
it "Should find if all elements are true along dimension" $ do
97-
A.allTrue (A.vector @Double 5 (repeat 12.0)) 0 `shouldBe` 1
98-
A.allTrue (A.vector @A.CBool 5 (repeat 1)) 0 `shouldBe` 1
99-
A.allTrue (A.vector @A.CBool 5 (repeat 0)) 0 `shouldBe` 0
97+
A.allTrue (A.vector @Double 5 (repeat 12.0)) 0 `shouldBe` A.scalar @A.CBool 1
98+
A.allTrue (A.vector @A.CBool 5 (repeat 1)) 0 `shouldBe` A.scalar @A.CBool 1
99+
A.allTrue (A.vector @A.CBool 5 (repeat 0)) 0 `shouldBe` A.scalar @A.CBool 0
100100
it "Should find if any elements are true along dimension" $ do
101-
A.anyTrue (A.vector @A.CBool 5 (repeat 1)) 0 `shouldBe` 1
102-
A.anyTrue (A.vector @Int 5 (repeat 23)) 0 `shouldBe` 1
103-
A.anyTrue (A.vector @A.CBool 5 (repeat 0)) 0 `shouldBe` 0
101+
A.anyTrue (A.vector @A.CBool 5 (repeat 1)) 0 `shouldBe` A.scalar @A.CBool 1
102+
A.anyTrue (A.vector @Int 5 (repeat 23)) 0 `shouldBe` A.scalar @A.CBool 1
103+
A.anyTrue (A.vector @A.CBool 5 (repeat 0)) 0 `shouldBe` A.scalar @A.CBool 0
104104
it "Should get count of all elements" $ do
105105
A.count (A.vector @Int 5 (repeat 1)) 0 `shouldBe` 5
106106
A.count (A.vector @A.CBool 5 (repeat 1)) 0 `shouldBe` 5
@@ -205,7 +205,7 @@ spec =
205205
describe "where'" $ do
206206
it "returns indices of nonzero elements" $ do
207207
A.where' (A.vector @Double 5 [0,1,0,2,0])
208-
`shouldBe` A.vector @Double 2 [1,3]
208+
`shouldBe` A.vector @A.Word32 2 [1,3]
209209
it "returns empty array when all elements are zero" $ do
210210
A.getDims (A.where' (A.vector @Double 3 [0,0,0]))
211211
`shouldBe` (0,1,1,1)

test/ArrayFire/ArithSpec.hs

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -140,15 +140,15 @@ spec =
140140
clamp (scalar @Int 2) (scalar @Int 1) (scalar @Int 3)
141141
`shouldBe` 2
142142
it "Should check if an array has positive or negative infinities" $ do
143-
isInf (scalar @Double (1 / 0)) `shouldBe` scalar @Double 1
144-
isInf (scalar @Double 10) `shouldBe` scalar @Double 0
143+
isInf (scalar @Double (1 / 0)) `shouldBe` scalar @CBool 1
144+
isInf (scalar @Double 10) `shouldBe` scalar @CBool 0
145145
it "Should check if an array has any NaN values" $ do
146-
ArrayFire.isNaN (scalar @Double (acos 2)) `shouldBe` scalar @Double 1
147-
ArrayFire.isNaN (scalar @Double 10) `shouldBe` scalar @Double 0
146+
ArrayFire.isNaN (scalar @Double (acos 2)) `shouldBe` scalar @CBool 1
147+
ArrayFire.isNaN (scalar @Double 10) `shouldBe` scalar @CBool 0
148148
it "Should check if an array has any Zero values" $ do
149-
isZero (scalar @Double (acos 2)) `shouldBe` scalar @Double 0
150-
isZero (scalar @Double 0) `shouldBe` scalar @Double 1
151-
isZero (scalar @Double 1) `shouldBe` scalar @Double 0
149+
isZero (scalar @Double (acos 2)) `shouldBe` scalar @CBool 0
150+
isZero (scalar @Double 0) `shouldBe` scalar @CBool 1
151+
isZero (scalar @Double 1) `shouldBe` scalar @CBool 0
152152

153153
prop "Floating @Float (exp)" $ \(x :: Float) -> exp `shouldMatchBuiltin` exp $ x
154154
prop "Floating @Float (log)" $ \(x :: Float) -> log `shouldMatchBuiltin` log $ x

0 commit comments

Comments
 (0)