Skip to content

Commit c47e3f2

Browse files
authored
Merge pull request #1133 from WardBrian/complex-containers
Complex containers
2 parents d5ddaef + 96ca12d commit c47e3f2

60 files changed

Lines changed: 13078 additions & 6453 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

src/analysis_and_optimization/Debug_data_generation.ml

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,22 @@ let gen_matrix mm m n t =
232232
Expr.Helpers.matrix_from_rows
233233
(repeat_th m (fun () -> gen_row_vector mm n t))
234234

235+
let gen_complex_unwrapped () =
236+
( gen_num_real Map.Poly.empty Transformation.Identity
237+
, gen_num_real Map.Poly.empty Transformation.Identity )
238+
239+
let gen_complex () = Expr.Helpers.complex (gen_complex_unwrapped ())
240+
241+
let gen_complex_row_vector n =
242+
Expr.Helpers.complex_row_vector (repeat_th n gen_complex_unwrapped)
243+
244+
let gen_complex_vector n =
245+
Expr.Helpers.complex_vector (repeat_th n gen_complex_unwrapped)
246+
247+
let gen_complex_matrix m n =
248+
Expr.Helpers.complex_matrix_from_rows
249+
(repeat_th m (fun () -> gen_complex_row_vector n))
250+
235251
let rec gen_array m st n t =
236252
let elt () = generate_value m st t in
237253
match (t : Expr.Typed.t Transformation.t) with
@@ -252,11 +268,15 @@ and generate_value m st t =
252268
| SReal -> Expr.Helpers.float (gen_num_real m t)
253269
| SComplex ->
254270
(* when serialzied, a complex number looks just like a 2-array of reals *)
255-
generate_value m (SArray (SReal, Expr.Helpers.int 2)) t
271+
gen_complex ()
256272
| SVector (_, e) -> gen_vector m (unwrap_int_exn m e) t
257273
| SRowVector (_, e) -> gen_row_vector m (unwrap_int_exn m e) t
258274
| SMatrix (_, e1, e2) ->
259275
gen_matrix m (unwrap_int_exn m e1) (unwrap_int_exn m e2) t
276+
| SComplexVector e -> gen_complex_vector (unwrap_int_exn m e)
277+
| SComplexRowVector e -> gen_complex_row_vector (unwrap_int_exn m e)
278+
| SComplexMatrix (e1, e2) ->
279+
gen_complex_matrix (unwrap_int_exn m e1) (unwrap_int_exn m e2)
260280
| SArray (st, e) -> gen_array m st (unwrap_int_exn m e) t
261281

262282
let rec pp_value_json ppf e =
@@ -267,6 +287,8 @@ let rec pp_value_json ppf e =
267287
| FunApp (StanLib (transpose, _, _), [e])
268288
when String.equal transpose (Operator.to_string Transpose) ->
269289
pp_value_json ppf e
290+
| FunApp (StanLib ("to_complex", _, _), [r; i]) ->
291+
Fmt.(pf ppf "[@[<hov 1>%a, %a@]]" pp_value_json r pp_value_json i)
270292
| _ ->
271293
Common.FatalError.fatal_error_msg
272294
[%message "Could not evaluate expression " (e : Expr.Typed.t)]

src/analysis_and_optimization/Mem_pattern.ml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -425,6 +425,9 @@ let rec query_initial_demotable_stmt (in_loop : bool) (acc : string Set.Poly.t)
425425
Set.Poly.union_list
426426
[ acc; query_expr acc predicate
427427
; query_initial_demotable_stmt true acc body ]
428+
| Decl {decl_type= Type.Sized st; decl_id; _}
429+
when SizedType.is_complex_type st ->
430+
Set.Poly.add acc decl_id
428431
| Skip | Break | Continue | Decl _ -> acc
429432

430433
(** Look through a statement to see whether the objects used in it need to be

src/frontend/Ast.ml

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -329,8 +329,13 @@ let get_loc_dt (t : untyped_expression Type.t) =
329329
match t with
330330
| Type.Unsized _ | Sized (SInt | SReal | SComplex) -> None
331331
| Sized
332-
(SVector (_, e) | SRowVector (_, e) | SMatrix (_, e, _) | SArray (_, e))
333-
->
332+
( SVector (_, e)
333+
| SRowVector (_, e)
334+
| SMatrix (_, e, _)
335+
| SComplexVector e
336+
| SComplexRowVector e
337+
| SComplexMatrix (e, _)
338+
| SArray (_, e) ) ->
334339
Some e.emeta.loc.begin_loc
335340

336341
let get_loc_tf (t : untyped_expression Transformation.t) =

src/frontend/Ast_to_Mir.ml

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,8 @@ let param_size transform sizedtype =
231231
| SizedType.SArray (t, d) -> SizedType.SArray (shrink_eigen f t, d)
232232
| SVector (mem_pattern, d) | SMatrix (mem_pattern, d, _) ->
233233
SVector (mem_pattern, f d)
234-
| SInt | SReal | SComplex | SRowVector _ ->
234+
| SInt | SReal | SComplex | SRowVector _ | SComplexRowVector _
235+
|SComplexVector _ | SComplexMatrix _ ->
235236
Common.FatalError.fatal_error_msg
236237
[%message
237238
"Expecting SVector or SMatrix, got " (st : Expr.Typed.t SizedType.t)]
@@ -240,7 +241,8 @@ let param_size transform sizedtype =
240241
match st with
241242
| SizedType.SArray (t, d) -> SizedType.SArray (shrink_eigen_mat f t, d)
242243
| SMatrix (mem_pattern, d1, d2) -> SVector (mem_pattern, f d1 d2)
243-
| SInt | SReal | SComplex | SRowVector _ | SVector _ ->
244+
| SInt | SReal | SComplex | SRowVector _ | SVector _ | SComplexRowVector _
245+
|SComplexVector _ | SComplexMatrix _ ->
244246
Common.FatalError.fatal_error_msg
245247
[%message "Expecting SMatrix, got " (st : Expr.Typed.t SizedType.t)]
246248
in
@@ -308,6 +310,16 @@ let check_sizedtype name =
308310
let er = trans_expr r in
309311
let ec = trans_expr c in
310312
(check r er @ check c ec, SizedType.SMatrix (mem_pattern, er, ec))
313+
| SComplexVector s ->
314+
let e = trans_expr s in
315+
(check s e, SizedType.SComplexVector e)
316+
| SComplexRowVector s ->
317+
let e = trans_expr s in
318+
(check s e, SizedType.SComplexRowVector e)
319+
| SComplexMatrix (r, c) ->
320+
let er = trans_expr r in
321+
let ec = trans_expr c in
322+
(check r er @ check c ec, SizedType.SComplexMatrix (er, ec))
311323
| SArray (t, s) ->
312324
let e = trans_expr s in
313325
let ll, t = sizedtype t in
@@ -581,6 +593,12 @@ let trans_sizedtype_decl declc tr name =
581593
| SRowVector (mem_pattern, s) ->
582594
let l, s = grab_size FnValidateSize n s in
583595
(l, SizedType.SRowVector (mem_pattern, s))
596+
| SComplexRowVector s ->
597+
let l, s = grab_size FnValidateSize n s in
598+
(l, SizedType.SComplexRowVector s)
599+
| SComplexVector s ->
600+
let l, s = grab_size FnValidateSize n s in
601+
(l, SizedType.SComplexVector s)
584602
| SMatrix (mem_pattern, r, c) ->
585603
let l1, r = grab_size FnValidateSize n r in
586604
let l2, c = grab_size FnValidateSize (n + 1) c in
@@ -598,6 +616,10 @@ let trans_sizedtype_decl declc tr name =
598616
; meta= r.Expr.Fixed.meta.Expr.Typed.Meta.loc } ]
599617
| _ -> [] in
600618
(l1 @ l2 @ cf_cov, SizedType.SMatrix (mem_pattern, r, c))
619+
| SComplexMatrix (r, c) ->
620+
let l1, r = grab_size FnValidateSize n r in
621+
let l2, c = grab_size FnValidateSize (n + 1) c in
622+
(l1 @ l2, SizedType.SComplexMatrix (r, c))
601623
| SArray (t, s) ->
602624
let l, s = grab_size FnValidateSize n s in
603625
let ll, t = go (n + 1) t in

src/frontend/Info.ml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,9 @@ let rec sized_basetype_dims t =
88
| SReal -> ("real", 0)
99
| SComplex -> ("complex", 0)
1010
| SVector _ | SRowVector _ -> ("real", 1)
11+
| SComplexVector _ | SComplexRowVector _ -> ("complex", 1)
1112
| SMatrix _ -> ("real", 2)
13+
| SComplexMatrix _ -> ("complex", 2)
1214
| SArray (t, _) ->
1315
let bt, n = sized_basetype_dims t in
1416
(bt, n + 1)
@@ -19,7 +21,9 @@ let rec unsized_basetype_dims t =
1921
| UReal -> ("real", 0)
2022
| UComplex -> ("complex", 0)
2123
| UVector | URowVector -> ("real", 1)
24+
| UComplexVector | UComplexRowVector -> ("complex", 1)
2225
| UMatrix -> ("real", 2)
26+
| UComplexMatrix -> ("complex", 2)
2327
| UArray t ->
2428
let bt, n = unsized_basetype_dims t in
2529
(bt, n + 1)

src/frontend/Pretty_printing.ml

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ let pp_autodifftype = Middle.UnsizedType.pp_autodifftype
180180
let rec unwind_sized_array_type st =
181181
match st with
182182
| Middle.SizedType.SInt | SReal | SComplex | SVector _ | SRowVector _
183-
|SMatrix _ ->
183+
|SMatrix _ | SComplexMatrix _ | SComplexVector _ | SComplexRowVector _ ->
184184
(st, [])
185185
| SArray (st, dim) ->
186186
let st', dims = unwind_sized_array_type st in
@@ -319,6 +319,10 @@ let rec pp_sizedtype ppf = function
319319
| SRowVector (_, e) -> pf ppf "row_vector[%a]" pp_expression e
320320
| SMatrix (_, e1, e2) ->
321321
pf ppf "matrix[%a, %a]" pp_expression e1 pp_expression e2
322+
| SComplexVector e -> pf ppf "complex_vector[%a]" pp_expression e
323+
| SComplexRowVector e -> pf ppf "complex_row_vector[%a]" pp_expression e
324+
| SComplexMatrix (e1, e2) ->
325+
pf ppf "complex_matrix[%a, %a]" pp_expression e1 pp_expression e2
322326
| SArray _ as arr ->
323327
let ty, ixs = unwind_sized_array_type arr in
324328
pf ppf "array[@[%a@]]@ %a"
@@ -348,9 +352,12 @@ let pp_transformed_type ppf (pst, trans) =
348352
let pp_possibly_transformed_type ppf (st, trans) =
349353
let sizes_fmt =
350354
match st with
351-
| SizedType.SVector (_, e) | SRowVector (_, e) ->
355+
| SizedType.SVector (_, e)
356+
|SRowVector (_, e)
357+
|SComplexVector e
358+
|SComplexRowVector e ->
352359
const (fun ppf -> pf ppf "[%a]" pp_expression) e
353-
| SMatrix (_, e1, e2) ->
360+
| SMatrix (_, e1, e2) | SComplexMatrix (e1, e2) ->
354361
const
355362
(fun ppf -> pf ppf "[%a, %a]" pp_expression e1 pp_expression)
356363
e2

src/frontend/Promotion.ml

Lines changed: 33 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -21,25 +21,27 @@ let promote_inner (exp : Ast.typed_expression) prom =
2121
{ expr= Ast.Promotion (exp, UReal, AutoDiffable)
2222
; emeta=
2323
{ emeta with
24-
type_= UnsizedType.promote_array emeta.type_ UReal
24+
type_= UnsizedType.promote_container emeta.type_ UReal
2525
; ad_level= AutoDiffable } }
2626
| ToComplexVar ->
2727
Ast.
2828
{ expr= Ast.Promotion (exp, UComplex, AutoDiffable)
2929
; emeta=
3030
{ emeta with
31-
type_= UnsizedType.promote_array emeta.type_ UComplex
31+
type_= UnsizedType.promote_container emeta.type_ UComplex
3232
; ad_level= AutoDiffable } }
3333
| IntToReal when UnsizedType.is_int_type emeta.type_ ->
3434
Ast.
3535
{ expr= Ast.Promotion (exp, UReal, emeta.ad_level)
36-
; emeta= {emeta with type_= UnsizedType.promote_array emeta.type_ UReal}
36+
; emeta=
37+
{emeta with type_= UnsizedType.promote_container emeta.type_ UReal}
3738
}
3839
| (IntToComplex | RealToComplex)
3940
when not (UnsizedType.is_complex_type emeta.type_) ->
4041
(* these two promotions are separated for cost, but are actually the same promotion *)
4142
{ expr= Promotion (exp, UComplex, emeta.ad_level)
42-
; emeta= {emeta with type_= UnsizedType.promote_array emeta.type_ UComplex}
43+
; emeta=
44+
{emeta with type_= UnsizedType.promote_container emeta.type_ UComplex}
4345
}
4446
| _ -> exp
4547

@@ -54,37 +56,54 @@ let rec promote (exp : Ast.typed_expression) prom =
5456
{ expr= ArrayExpr pes
5557
; emeta=
5658
{ exp.emeta with
57-
type_= UnsizedType.promote_array exp.emeta.type_ type_
59+
type_= UnsizedType.promote_container exp.emeta.type_ type_
5860
; ad_level } }
5961
| RowVectorExpr (_ :: _ as es) ->
6062
let pes = List.map ~f:(fun e -> promote e prom) es in
6163
let fst = List.hd_exn pes in
6264
let ad_level = fst.emeta.ad_level in
63-
{expr= RowVectorExpr pes; emeta= {exp.emeta with ad_level}}
65+
let type_ =
66+
(* "RowVectorExpr" can also be a matrix expr, depends on what is inside *)
67+
match fst.emeta.type_ with
68+
| UComplexRowVector -> UnsizedType.UComplexMatrix
69+
| URowVector -> UMatrix
70+
| UComplex -> UComplexRowVector
71+
| _ -> URowVector in
72+
{expr= RowVectorExpr pes; emeta= {exp.emeta with type_; ad_level}}
6473
| _ -> promote_inner exp prom
6574

6675
let promote_list es promotions = List.map2_exn es promotions ~f:promote
6776

6877
(** Get the promotion needed to make the second type into the first.
6978
Types NEED to have previously been checked to be promotable
7079
*)
71-
let rec get_type_promotion_exn (ad, ty) (ad2, ty2) =
72-
match (ty, ty2) with
80+
let rec get_type_promotion_exn (ad_orig, ty_orig) (ad_expect, ty_expect) =
81+
match (ty_orig, ty_expect) with
7382
| UnsizedType.(UReal, (UReal | UInt) | UVector, UVector | UMatrix, UMatrix)
74-
when ad <> ad2 ->
83+
when ad_orig <> ad_expect ->
7584
ToVar
76-
| UComplex, (UReal | UInt | UComplex) when ad <> ad2 -> ToComplexVar
85+
| UComplex, (UReal | UInt | UComplex)
86+
|UComplexMatrix, (UMatrix | UComplexMatrix)
87+
|UComplexVector, (UVector | UComplexVector)
88+
|UComplexRowVector, (URowVector | UComplexRowVector)
89+
when ad_orig <> ad_expect ->
90+
ToComplexVar
7791
| UReal, UInt -> IntToReal
7892
| UComplex, UInt -> IntToComplex
79-
| UComplex, UReal -> RealToComplex
80-
| UArray nt1, UArray nt2 -> get_type_promotion_exn (ad, nt1) (ad2, nt2)
93+
| UComplex, UReal
94+
|UComplexMatrix, UMatrix
95+
|UComplexVector, UVector
96+
|UComplexRowVector, URowVector ->
97+
RealToComplex
98+
| UArray nt1, UArray nt2 ->
99+
get_type_promotion_exn (ad_orig, nt1) (ad_expect, nt2)
81100
| t1, t2 when t1 = t2 -> NoPromotion
82101
| _, _ ->
83102
Common.FatalError.fatal_error_msg
84103
[%message
85104
"Tried to get promotion of mismatched types!"
86-
(ty : UnsizedType.t)
87-
(ty2 : UnsizedType.t)]
105+
(ty_orig : UnsizedType.t)
106+
(ty_expect : UnsizedType.t)]
88107

89108
(** Calculate the "cost"/number of promotions performed.
90109
Used to disambiguate function signatures

src/frontend/SignatureMismatch.ml

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@ let pp_unsized_type ctx ppf =
1717
let rec pp ppf ty =
1818
match ty with
1919
| UnsizedType.UInt | UReal | UVector | URowVector | UMatrix | UComplex
20-
|UMathLibraryFunction ->
20+
|UComplexRowVector | UComplexVector | UComplexMatrix | UMathLibraryFunction
21+
->
2122
UnsizedType.pp ppf ty
2223
| UArray ut ->
2324
let ut2, d = UnsizedType.unwind_array_type ut in
@@ -140,8 +141,13 @@ let rec check_same_type depth t1 t2 =
140141
match (t1, t2) with
141142
| t1, t2 when t1 = t2 -> Ok Promotion.NoPromotion
142143
| UnsizedType.(UReal, UInt) when depth < 1 -> Ok IntToReal
143-
| UnsizedType.(UComplex, UInt) when depth < 1 -> Ok IntToComplex
144-
| UnsizedType.(UComplex, UReal) when depth < 1 -> Ok RealToComplex
144+
| UComplex, UInt when depth < 1 -> Ok IntToComplex
145+
| UComplex, UReal
146+
|UComplexMatrix, UMatrix
147+
|UComplexVector, UVector
148+
|UComplexRowVector, URowVector
149+
when depth < 1 ->
150+
Ok RealToComplex
145151
(* Arrays: Try to recursively promote, but make sure the error is for these types,
146152
not the recursive call *)
147153
| UArray nt1, UArray nt2 ->

0 commit comments

Comments
 (0)