@@ -66,52 +66,57 @@ let list_multi_tildes (mir : Program.Typed.t) :
6666 ~f: (fun ~key ~data s -> Set. add s (key, data))
6767 multi_tildes
6868
69- (* * These functions are linear if all of their arguments are *)
70- let linear_fnames =
71- Operator. ([Plus ; PPlus ; Minus ; PMinus ] |> List. map ~f: to_string)
72- @ [ " add" ; " append_block" ; " append_row" ; " append_col" ; " block" ; " col" ; " cols"
73- ; " row" ; " rows" ; " diagonal" ; " head" ; " tail" ; " minus" ; " negative_infinity"
74- ; " not_a_number" ; " rep_matrix" ; " rep_vector" ; " rep_row_vector"
75- ; " positive_infinity" ; " segment" ; " subtract" ; " sum" ; " to_vector"
76- ; " to_row_vector" ; " to_matrix" ; " to_array_1d" ; " to_array_2d" ; " transpose"
77- ; " Plus__" ; " PPlus__" ; " PMinus__" ; " Minus__" ; " PNot__" ; " Transpose__" ]
78- |> String.Set. of_list
79-
69+ (* * Collect statements of the form "target += Dist(param, ...)" where param
70+ has possibly been transformed non-linearly *)
8071let list_possible_nonlinear (mir : Program.Typed.t ) : Location_span. t Set.Poly. t
8172 =
82- (* Collect statements of the form "target += Dist(param, ...)" *)
83- let rec is_linear_function allow_var name (args : 'a Expr.Fixed.t list ) =
73+ (* These functions are linear if all of their arguments are *)
74+ let linear_fnames =
75+ Operator. ([Plus ; PPlus ; Minus ; PMinus ] |> List. map ~f: to_string)
76+ @ [ " add" ; " append_block" ; " append_row" ; " append_col" ; " block" ; " col" ; " cols"
77+ ; " row" ; " rows" ; " diagonal" ; " head" ; " tail" ; " minus" ; " negative_infinity"
78+ ; " not_a_number" ; " rep_matrix" ; " rep_vector" ; " rep_row_vector"
79+ ; " positive_infinity" ; " segment" ; " subtract" ; " sum" ; " to_vector"
80+ ; " to_row_vector" ; " to_matrix" ; " to_array_1d" ; " to_array_2d" ; " transpose"
81+ ; " Plus__" ; " PPlus__" ; " PMinus__" ; " Minus__" ; " PNot__" ; " Transpose__" ]
82+ |> String.Set. of_list in
83+ (* A simple check of linearity of an expression.
84+ allow_var is used for expressions like a*b, where at most one
85+ of a and b can be a variable
86+ *)
87+ let rec is_linear allow_var Expr.Fixed. {pattern; _} =
88+ match pattern with
89+ | Expr.Fixed.Pattern. Var _ -> allow_var
90+ | Lit _ -> true
91+ | Indexed (e , _ ) | Promotion (e , _ , _ ) -> is_linear allow_var e
92+ | TernaryIf (e1 , e2 , e3 ) ->
93+ is_linear allow_var e1 && is_linear allow_var e2
94+ && is_linear allow_var e3
95+ | FunApp (StanLib (name , _ , _ ), args ) ->
96+ is_linear_function allow_var name args
97+ | FunApp (CompilerInternal (FnMakeArray | FnMakeRowVec ), args ) ->
98+ List. for_all ~f: (is_linear allow_var) args
99+ | _ -> false
100+ and is_linear_function allow_var name (args : 'a Expr.Fixed.t list ) =
84101 match (name, args) with
85102 | _ , _ when Set. mem linear_fnames name ->
86103 List. for_all ~f: (is_linear allow_var) args
87104 | _ , _ when List. for_all ~f: (is_linear false ) args ->
88105 (* A function of all constants is fine *) true
89- (* We require at least one of these operands to be a constant *)
90106 | ("Times__" | "Divide__" | "IntDivide__" ), [a; b] ->
107+ (* We require at least one of these operands to be a constant *)
91108 (is_linear allow_var a && is_linear false b)
92109 || (is_linear false a && is_linear allow_var b)
93- (* Partial evaluation can create fmas where the user wrote Times*)
94110 | "fma" , [a; b; c] ->
111+ (* Similar to above.
112+ Partial evaluation can create fmas where the user wrote Times *)
95113 is_linear allow_var c
96114 && ( (is_linear allow_var a && is_linear false b)
97115 || (is_linear false a && is_linear allow_var b) )
98- | _ -> false
99- and is_linear allow_var Expr.Fixed. {pattern; _} =
100- match pattern with
101- | Expr.Fixed.Pattern. Var _ -> allow_var
102- | Lit _ -> true
103- | Indexed (e , _ ) | Promotion (e , _ , _ ) -> is_linear allow_var e
104- | TernaryIf (e1 , e2 , e3 ) ->
105- is_linear allow_var e1 && is_linear allow_var e2
106- && is_linear allow_var e3
107- | FunApp (StanLib (name , _ , _ ), args ) ->
108- is_linear_function allow_var name args
109- | FunApp (CompilerInternal (FnMakeArray | FnMakeRowVec ), args ) ->
110- List. for_all ~f: (is_linear allow_var) args
111116 | _ -> false in
112- let collect_transformed_tilde_stmt (stmt : Stmt.Located.t ) :
113- Location_span. t Set.Poly. t =
117+ let maybe_nonlinear_tilde (stmt : Stmt.Located.t ) =
114118 match stmt.pattern with
119+ (* a ~ foo(...) gets translated to target += foo_lpdf(a, ...) *)
115120 | Stmt.Fixed.Pattern. TargetPE
116121 { pattern=
117122 Expr.Fixed.Pattern. FunApp
@@ -120,12 +125,12 @@ let list_possible_nonlinear (mir : Program.Typed.t) : Location_span.t Set.Poly.t
120125 when not (is_linear true e) ->
121126 Set.Poly. singleton stmt.meta
122127 | _ -> Set.Poly. empty in
123- let tildes =
128+ let bad_tildes =
124129 fold_stmts
125- ~take_stmt: (fun m s -> Set.Poly. union m (collect_transformed_tilde_stmt s))
130+ ~take_stmt: (fun m s -> Set.Poly. union m (maybe_nonlinear_tilde s))
126131 ~take_expr: (fun m _ -> m)
127132 ~init: Set.Poly. empty mir.log_prob in
128- tildes
133+ bad_tildes
129134
130135(* Find all of the targets which are dependencies for a given label *)
131136let var_deps info_map label ?expr :(expr_opt : Expr.Typed.t option = None )
0 commit comments