@@ -66,6 +66,73 @@ let list_multi_tildes (mir : Program.Typed.t) :
6666 ~f: (fun ~key ~data s -> Set. add s (key, data))
6767 multi_tildes
6868
69+ (* * Collect statements of the form "target += Dist(param, ...)" where param
70+ has possibly been transformed non-linearly *)
71+ let list_possible_nonlinear (mir : Program.Typed.t ) : Location_span. t Set.Poly. t
72+ =
73+ (* These functions are linear if all of their arguments are *)
74+ let linear_fnames =
75+ Operator. (
76+ [Plus ; PPlus ; Minus ; PMinus ; PNot ; Transpose ] |> List. map ~f: to_string)
77+ @ [ " add" ; " append_block" ; " append_row" ; " append_col" ; " block" ; " col" ; " cols"
78+ ; " row" ; " rows" ; " diagonal" ; " head" ; " tail" ; " minus" ; " negative_infinity"
79+ ; " not_a_number" ; " rep_matrix" ; " rep_vector" ; " rep_row_vector"
80+ ; " positive_infinity" ; " segment" ; " subtract" ; " sum" ; " to_vector"
81+ ; " to_row_vector" ; " to_matrix" ; " to_array_1d" ; " to_array_2d" ; " transpose"
82+ ]
83+ |> String.Set. of_list in
84+ (* A simple check of linearity of an expression.
85+ allow_var is used for expressions like a*b, where at most one
86+ of a and b can be a variable
87+ *)
88+ let rec is_linear allow_var Expr.Fixed. {pattern; _} =
89+ match pattern with
90+ | Expr.Fixed.Pattern. Var _ -> allow_var
91+ | Lit _ -> true
92+ | Indexed (e , _ ) | Promotion (e , _ , _ ) -> is_linear allow_var e
93+ | TernaryIf (e1 , e2 , e3 ) ->
94+ is_linear allow_var e1 && is_linear allow_var e2
95+ && is_linear allow_var e3
96+ | FunApp (StanLib (name , _ , _ ), args ) ->
97+ is_linear_function allow_var name args
98+ | FunApp (CompilerInternal (FnMakeArray | FnMakeRowVec ), args ) ->
99+ List. for_all ~f: (is_linear allow_var) args
100+ | _ -> false
101+ and is_linear_function allow_var name (args : 'a Expr.Fixed.t list ) =
102+ match (name, args) with
103+ | _ , _ when Set. mem linear_fnames name ->
104+ List. for_all ~f: (is_linear allow_var) args
105+ | _ , _ when List. for_all ~f: (is_linear false ) args ->
106+ (* A function of all constants is fine *) true
107+ | ("Times__" | "Divide__" | "IntDivide__" ), [a; b] ->
108+ (* We require at least one of these operands to be a constant *)
109+ (is_linear allow_var a && is_linear false b)
110+ || (is_linear false a && is_linear allow_var b)
111+ | "fma" , [a; b; c] ->
112+ (* Similar to above.
113+ Partial evaluation can create fmas where the user wrote Times *)
114+ is_linear allow_var c
115+ && ( (is_linear allow_var a && is_linear false b)
116+ || (is_linear false a && is_linear allow_var b) )
117+ | _ -> false in
118+ let maybe_nonlinear_tilde (stmt : Stmt.Located.t ) =
119+ match stmt.pattern with
120+ (* a ~ foo(...) gets translated to target += foo_lpdf(a, ...) *)
121+ | Stmt.Fixed.Pattern. TargetPE
122+ { pattern=
123+ Expr.Fixed.Pattern. FunApp
124+ ((StanLib (_, FnLpdf _, _) | UserDefined (_, FnLpdf _)), e :: _)
125+ ; _ }
126+ when not (is_linear true e) ->
127+ Set.Poly. singleton stmt.meta
128+ | _ -> Set.Poly. empty in
129+ let bad_tildes =
130+ fold_stmts
131+ ~take_stmt: (fun m s -> Set.Poly. union m (maybe_nonlinear_tilde s))
132+ ~take_expr: (fun m _ -> m)
133+ ~init: Set.Poly. empty mir.log_prob in
134+ bad_tildes
135+
69136(* Find all of the targets which are dependencies for a given label *)
70137let var_deps info_map label ?expr :(expr_opt : Expr.Typed.t option = None )
71138 (targets : string Set.Poly.t ) : string Set.Poly. t =
@@ -130,7 +197,7 @@ let list_arg_dependant_fundef_cf (mir : Program.Typed.t)
130197 Option. value_exn
131198 ~message:
132199 " INTERNAL ERROR: Pedantic mode found CF dependent on an \
133- arg,but the arg is mismatched. Please report a bug.\n "
200+ arg, but the arg is mismatched. Please report a bug.\n "
134201 (List. findi args ~f: (fun _ arg -> arg = name)) in
135202 (loc, ix, name) ) ) )
136203
@@ -320,6 +387,17 @@ let hard_constrained_warnings (mir : Program.Typed.t) =
320387 (Location_span. empty, nonsense_constrained_message pname) )
321388 pnames
322389
390+ let maybe_jacobian_adjustment_warnings (mir : Program.Typed.t ) =
391+ let locations = list_possible_nonlinear mir in
392+ Set.Poly. map
393+ ~f: (fun loc ->
394+ ( loc
395+ , " Left-hand side of sampling statement (~) may contain a non-linear \
396+ transform of a parameter or local variable. If it does, you need to \
397+ include a target += statement with the log absolute determinant of \
398+ the Jacobian of the transform." ) )
399+ locations
400+
323401let multi_tildes_message (vname : string ) : string =
324402 Printf. sprintf
325403 " The parameter %s is on the left-hand side of more than one tilde \
@@ -421,9 +499,9 @@ let warn_pedantic (mir_unopt : Program.Typed.t) =
421499 let factor_graph = prog_factor_graph mir in
422500 Set.Poly. union_list
423501 [ uninitialized_warnings mir; unscaled_constants_warnings distributions_info
424- ; multi_tildes_warnings mir; hard_constrained_warnings mir
425- ; unused_params_warnings factor_graph mir; param_dependant_cf_warnings mir
426- ; param_dependant_fundef_cf_warnings mir
502+ ; multi_tildes_warnings mir; maybe_jacobian_adjustment_warnings mir
503+ ; hard_constrained_warnings mir; unused_params_warnings factor_graph mir
504+ ; param_dependant_cf_warnings mir; param_dependant_fundef_cf_warnings mir
427505 ; non_one_priors_warnings factor_graph mir
428506 ; distribution_warnings distributions_info ]
429507 |> to_list
0 commit comments