22{-# LANGUAGE TypeApplications #-}
33module ArrayFire.BLASSpec where
44
5- import ArrayFire hiding (not )
5+ import ArrayFire hiding (not , and , abs , max )
66
77import Data.Complex
88import Test.Hspec
@@ -13,6 +13,22 @@ import Test.Hspec.QuickCheck (prop)
1313mat4 :: [Double ] -> Array Double
1414mat4 xs = mkArray [4 ,4 ] (take 16 (xs ++ repeat 0 ))
1515
16+ -- | Build a length-4 'Double' vector, padding with zeros.
17+ vec4 :: [Double ] -> Array Double
18+ vec4 xs = vector 4 (take 4 (xs ++ repeat 0 ))
19+
20+ -- | Plain matrix product with default (None) operands.
21+ mm :: Array Double -> Array Double -> Array Double
22+ mm a b = (a `matmul` b) None None
23+
24+ -- | Transpose (no conjugation).
25+ tr :: Array Double -> Array Double
26+ tr a = transpose a False
27+
28+ -- | Scale every element of a 4x4 matrix by a constant.
29+ scaleMat :: Double -> Array Double -> Array Double
30+ scaleMat c a = mkArray [4 ,4 ] (map (c * ) (toList a))
31+
1632-- | Element-wise closeness, tolerant of floating-point rounding in BLAS.
1733closeList :: [Double ] -> [Double ] -> Bool
1834closeList as bs =
@@ -83,3 +99,53 @@ spec =
8399 lhs = transpose ((transpose a False `matmul` transpose b False ) None None ) False
84100 rhs = (b `matmul` a) None None
85101 in closeList (toList lhs) (toList rhs)
102+
103+ -- Matrix multiplication is associative.
104+ prop " (A*B)*C = A*(B*C)" $ \ (xs :: [Double ]) (ys :: [Double ]) (zs :: [Double ]) ->
105+ let a = mat4 xs; b = mat4 ys; c = mat4 zs
106+ in closeList (toList (mm (mm a b) c)) (toList (mm a (mm b c)))
107+
108+ -- Multiplication distributes over addition on the left.
109+ prop " A*(B+C) = A*B + A*C" $ \ (xs :: [Double ]) (ys :: [Double ]) (zs :: [Double ]) ->
110+ let a = mat4 xs; b = mat4 ys; c = mat4 zs
111+ in closeList (toList (mm a (b + c))) (toList (mm a b + mm a c))
112+
113+ -- Multiplication distributes over addition on the right.
114+ prop " (A+B)*C = A*C + B*C" $ \ (xs :: [Double ]) (ys :: [Double ]) (zs :: [Double ]) ->
115+ let a = mat4 xs; b = mat4 ys; c = mat4 zs
116+ in closeList (toList (mm (a + b) c)) (toList (mm a c + mm b c))
117+
118+ -- The identity is a left identity too (the existing case is right-sided).
119+ prop " I*A = A" $ \ (xs :: [Double ]) ->
120+ let a = mat4 xs
121+ in closeList (toList (mm (identity [4 ,4 ]) a)) (toList a)
122+
123+ -- Transpose of a product reverses the order of the factors.
124+ prop " (A*B)^T = B^T * A^T" $ \ (xs :: [Double ]) (ys :: [Double ]) ->
125+ let a = mat4 xs; b = mat4 ys
126+ in closeList (toList (tr (mm a b))) (toList (mm (tr b) (tr a)))
127+
128+ -- Transpose is additive.
129+ prop " (A+B)^T = A^T + B^T" $ \ (xs :: [Double ]) (ys :: [Double ]) ->
130+ let a = mat4 xs; b = mat4 ys
131+ in closeList (toList (tr (a + b))) (toList (tr a + tr b))
132+
133+ -- Scalar factors pull through a product: (cA)*B = c(A*B).
134+ prop " (cA)*B = c(A*B)" $ \ (c :: Double ) (xs :: [Double ]) (ys :: [Double ]) ->
135+ let a = mat4 xs; b = mat4 ys
136+ in closeList (toList (mm (scaleMat c a) b)) (toList (scaleMat c (mm a b)))
137+
138+ -- The zero matrix annihilates under multiplication.
139+ prop " A*0 = 0" $ \ (xs :: [Double ]) ->
140+ let a = mat4 xs
141+ in all (== 0 ) (toList (mm a (mat4 [] )))
142+
143+ -- gemm with alpha=1 and no transposition agrees with matmul.
144+ prop " gemm None None 1 A B = A*B" $ \ (xs :: [Double ]) (ys :: [Double ]) ->
145+ let a = mat4 xs; b = mat4 ys
146+ in closeList (toList (gemm None None 1.0 a b)) (toList (mm a b))
147+
148+ -- The dot product of real vectors is symmetric.
149+ prop " dot x y = dot y x" $ \ (xs :: [Double ]) (ys :: [Double ]) ->
150+ let x = vec4 xs; y = vec4 ys
151+ in closeList (toList (dot x y None None )) (toList (dot y x None None ))
0 commit comments