Skip to content

Commit d9b9838

Browse files
committed
Comments
1 parent 0f352a6 commit d9b9838

1 file changed

Lines changed: 38 additions & 33 deletions

File tree

src/analysis_and_optimization/Pedantic_analysis.ml

Lines changed: 38 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -66,52 +66,57 @@ let list_multi_tildes (mir : Program.Typed.t) :
6666
~f:(fun ~key ~data s -> Set.add s (key, data))
6767
multi_tildes
6868

69-
(** These functions are linear if all of their arguments are *)
70-
let linear_fnames =
71-
Operator.([Plus; PPlus; Minus; PMinus] |> List.map ~f:to_string)
72-
@ [ "add"; "append_block"; "append_row"; "append_col"; "block"; "col"; "cols"
73-
; "row"; "rows"; "diagonal"; "head"; "tail"; "minus"; "negative_infinity"
74-
; "not_a_number"; "rep_matrix"; "rep_vector"; "rep_row_vector"
75-
; "positive_infinity"; "segment"; "subtract"; "sum"; "to_vector"
76-
; "to_row_vector"; "to_matrix"; "to_array_1d"; "to_array_2d"; "transpose"
77-
; "Plus__"; "PPlus__"; "PMinus__"; "Minus__"; "PNot__"; "Transpose__" ]
78-
|> String.Set.of_list
79-
69+
(** Collect statements of the form "target += Dist(param, ...)" where param
70+
has possibly been transformed non-linearly *)
8071
let list_possible_nonlinear (mir : Program.Typed.t) : Location_span.t Set.Poly.t
8172
=
82-
(* Collect statements of the form "target += Dist(param, ...)" *)
83-
let rec is_linear_function allow_var name (args : 'a Expr.Fixed.t list) =
73+
(* These functions are linear if all of their arguments are *)
74+
let linear_fnames =
75+
Operator.([Plus; PPlus; Minus; PMinus] |> List.map ~f:to_string)
76+
@ [ "add"; "append_block"; "append_row"; "append_col"; "block"; "col"; "cols"
77+
; "row"; "rows"; "diagonal"; "head"; "tail"; "minus"; "negative_infinity"
78+
; "not_a_number"; "rep_matrix"; "rep_vector"; "rep_row_vector"
79+
; "positive_infinity"; "segment"; "subtract"; "sum"; "to_vector"
80+
; "to_row_vector"; "to_matrix"; "to_array_1d"; "to_array_2d"; "transpose"
81+
; "Plus__"; "PPlus__"; "PMinus__"; "Minus__"; "PNot__"; "Transpose__" ]
82+
|> String.Set.of_list in
83+
(* A simple check of linearity of an expression.
84+
allow_var is used for expressions like a*b, where at most one
85+
of a and b can be a variable
86+
*)
87+
let rec is_linear allow_var Expr.Fixed.{pattern; _} =
88+
match pattern with
89+
| Expr.Fixed.Pattern.Var _ -> allow_var
90+
| Lit _ -> true
91+
| Indexed (e, _) | Promotion (e, _, _) -> is_linear allow_var e
92+
| TernaryIf (e1, e2, e3) ->
93+
is_linear allow_var e1 && is_linear allow_var e2
94+
&& is_linear allow_var e3
95+
| FunApp (StanLib (name, _, _), args) ->
96+
is_linear_function allow_var name args
97+
| FunApp (CompilerInternal (FnMakeArray | FnMakeRowVec), args) ->
98+
List.for_all ~f:(is_linear allow_var) args
99+
| _ -> false
100+
and is_linear_function allow_var name (args : 'a Expr.Fixed.t list) =
84101
match (name, args) with
85102
| _, _ when Set.mem linear_fnames name ->
86103
List.for_all ~f:(is_linear allow_var) args
87104
| _, _ when List.for_all ~f:(is_linear false) args ->
88105
(* A function of all constants is fine *) true
89-
(* We require at least one of these operands to be a constant *)
90106
| ("Times__" | "Divide__" | "IntDivide__"), [a; b] ->
107+
(* We require at least one of these operands to be a constant *)
91108
(is_linear allow_var a && is_linear false b)
92109
|| (is_linear false a && is_linear allow_var b)
93-
(* Partial evaluation can create fmas where the user wrote Times*)
94110
| "fma", [a; b; c] ->
111+
(* Similar to above.
112+
Partial evaluation can create fmas where the user wrote Times *)
95113
is_linear allow_var c
96114
&& ( (is_linear allow_var a && is_linear false b)
97115
|| (is_linear false a && is_linear allow_var b) )
98-
| _ -> false
99-
and is_linear allow_var Expr.Fixed.{pattern; _} =
100-
match pattern with
101-
| Expr.Fixed.Pattern.Var _ -> allow_var
102-
| Lit _ -> true
103-
| Indexed (e, _) | Promotion (e, _, _) -> is_linear allow_var e
104-
| TernaryIf (e1, e2, e3) ->
105-
is_linear allow_var e1 && is_linear allow_var e2
106-
&& is_linear allow_var e3
107-
| FunApp (StanLib (name, _, _), args) ->
108-
is_linear_function allow_var name args
109-
| FunApp (CompilerInternal (FnMakeArray | FnMakeRowVec), args) ->
110-
List.for_all ~f:(is_linear allow_var) args
111116
| _ -> false in
112-
let collect_transformed_tilde_stmt (stmt : Stmt.Located.t) :
113-
Location_span.t Set.Poly.t =
117+
let maybe_nonlinear_tilde (stmt : Stmt.Located.t) =
114118
match stmt.pattern with
119+
(* a ~ foo(...) gets translated to target += foo_lpdf(a, ...) *)
115120
| Stmt.Fixed.Pattern.TargetPE
116121
{ pattern=
117122
Expr.Fixed.Pattern.FunApp
@@ -120,12 +125,12 @@ let list_possible_nonlinear (mir : Program.Typed.t) : Location_span.t Set.Poly.t
120125
when not (is_linear true e) ->
121126
Set.Poly.singleton stmt.meta
122127
| _ -> Set.Poly.empty in
123-
let tildes =
128+
let bad_tildes =
124129
fold_stmts
125-
~take_stmt:(fun m s -> Set.Poly.union m (collect_transformed_tilde_stmt s))
130+
~take_stmt:(fun m s -> Set.Poly.union m (maybe_nonlinear_tilde s))
126131
~take_expr:(fun m _ -> m)
127132
~init:Set.Poly.empty mir.log_prob in
128-
tildes
133+
bad_tildes
129134

130135
(* Find all of the targets which are dependencies for a given label *)
131136
let var_deps info_map label ?expr:(expr_opt : Expr.Typed.t option = None)

0 commit comments

Comments
 (0)