@@ -21,25 +21,27 @@ let promote_inner (exp : Ast.typed_expression) prom =
2121 { expr= Ast. Promotion (exp, UReal , AutoDiffable )
2222 ; emeta=
2323 { emeta with
24- type_= UnsizedType. promote_array emeta.type_ UReal
24+ type_= UnsizedType. promote_container emeta.type_ UReal
2525 ; ad_level= AutoDiffable } }
2626 | ToComplexVar ->
2727 Ast.
2828 { expr= Ast. Promotion (exp, UComplex , AutoDiffable )
2929 ; emeta=
3030 { emeta with
31- type_= UnsizedType. promote_array emeta.type_ UComplex
31+ type_= UnsizedType. promote_container emeta.type_ UComplex
3232 ; ad_level= AutoDiffable } }
3333 | IntToReal when UnsizedType. is_int_type emeta.type_ ->
3434 Ast.
3535 { expr= Ast. Promotion (exp, UReal , emeta.ad_level)
36- ; emeta= {emeta with type_= UnsizedType. promote_array emeta.type_ UReal }
36+ ; emeta=
37+ {emeta with type_= UnsizedType. promote_container emeta.type_ UReal }
3738 }
3839 | (IntToComplex | RealToComplex )
3940 when not (UnsizedType. is_complex_type emeta.type_) ->
4041 (* these two promotions are separated for cost, but are actually the same promotion *)
4142 { expr= Promotion (exp, UComplex , emeta.ad_level)
42- ; emeta= {emeta with type_= UnsizedType. promote_array emeta.type_ UComplex }
43+ ; emeta=
44+ {emeta with type_= UnsizedType. promote_container emeta.type_ UComplex }
4345 }
4446 | _ -> exp
4547
@@ -54,37 +56,54 @@ let rec promote (exp : Ast.typed_expression) prom =
5456 { expr= ArrayExpr pes
5557 ; emeta=
5658 { exp.emeta with
57- type_= UnsizedType. promote_array exp.emeta.type_ type_
59+ type_= UnsizedType. promote_container exp.emeta.type_ type_
5860 ; ad_level } }
5961 | RowVectorExpr (_ :: _ as es ) ->
6062 let pes = List. map ~f: (fun e -> promote e prom) es in
6163 let fst = List. hd_exn pes in
6264 let ad_level = fst.emeta.ad_level in
63- {expr= RowVectorExpr pes; emeta= {exp.emeta with ad_level}}
65+ let type_ =
66+ (* "RowVectorExpr" can also be a matrix expr, depends on what is inside *)
67+ match fst.emeta.type_ with
68+ | UComplexRowVector -> UnsizedType. UComplexMatrix
69+ | URowVector -> UMatrix
70+ | UComplex -> UComplexRowVector
71+ | _ -> URowVector in
72+ {expr= RowVectorExpr pes; emeta= {exp.emeta with type_; ad_level}}
6473 | _ -> promote_inner exp prom
6574
6675let promote_list es promotions = List. map2_exn es promotions ~f: promote
6776
6877(* * Get the promotion needed to make the second type into the first.
6978 Types NEED to have previously been checked to be promotable
7079*)
71- let rec get_type_promotion_exn (ad , ty ) (ad2 , ty2 ) =
72- match (ty, ty2 ) with
80+ let rec get_type_promotion_exn (ad_orig , ty_orig ) (ad_expect , ty_expect ) =
81+ match (ty_orig, ty_expect ) with
7382 | UnsizedType. (UReal , (UReal | UInt ) | UVector , UVector | UMatrix , UMatrix )
74- when ad <> ad2 ->
83+ when ad_orig <> ad_expect ->
7584 ToVar
76- | UComplex , (UReal | UInt | UComplex ) when ad <> ad2 -> ToComplexVar
85+ | UComplex , (UReal | UInt | UComplex )
86+ | UComplexMatrix , (UMatrix | UComplexMatrix )
87+ | UComplexVector , (UVector | UComplexVector )
88+ | UComplexRowVector , (URowVector | UComplexRowVector )
89+ when ad_orig <> ad_expect ->
90+ ToComplexVar
7791 | UReal , UInt -> IntToReal
7892 | UComplex , UInt -> IntToComplex
79- | UComplex , UReal -> RealToComplex
80- | UArray nt1 , UArray nt2 -> get_type_promotion_exn (ad, nt1) (ad2, nt2)
93+ | UComplex , UReal
94+ | UComplexMatrix , UMatrix
95+ | UComplexVector , UVector
96+ | UComplexRowVector , URowVector ->
97+ RealToComplex
98+ | UArray nt1 , UArray nt2 ->
99+ get_type_promotion_exn (ad_orig, nt1) (ad_expect, nt2)
81100 | t1 , t2 when t1 = t2 -> NoPromotion
82101 | _ , _ ->
83102 Common.FatalError. fatal_error_msg
84103 [% message
85104 " Tried to get promotion of mismatched types!"
86- (ty : UnsizedType.t )
87- (ty2 : UnsizedType.t )]
105+ (ty_orig : UnsizedType.t )
106+ (ty_expect : UnsizedType.t )]
88107
89108(* * Calculate the "cost"/number of promotions performed.
90109 Used to disambiguate function signatures
0 commit comments