Skip to content

Commit ba71588

Browse files
authored
Merge pull request #1075 from WardBrian/jacobian-warning
Add Jacobian adjustment warning to pedantic mode
2 parents 19be33f + 2ecca82 commit ba71588

13 files changed

Lines changed: 223 additions & 66 deletions

src/analysis_and_optimization/Pedantic_analysis.ml

Lines changed: 82 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,73 @@ let list_multi_tildes (mir : Program.Typed.t) :
6666
~f:(fun ~key ~data s -> Set.add s (key, data))
6767
multi_tildes
6868

69+
(** Collect statements of the form "target += Dist(param, ...)" where param
70+
has possibly been transformed non-linearly *)
71+
let list_possible_nonlinear (mir : Program.Typed.t) : Location_span.t Set.Poly.t
72+
=
73+
(* These functions are linear if all of their arguments are *)
74+
let linear_fnames =
75+
Operator.(
76+
[Plus; PPlus; Minus; PMinus; PNot; Transpose] |> List.map ~f:to_string)
77+
@ [ "add"; "append_block"; "append_row"; "append_col"; "block"; "col"; "cols"
78+
; "row"; "rows"; "diagonal"; "head"; "tail"; "minus"; "negative_infinity"
79+
; "not_a_number"; "rep_matrix"; "rep_vector"; "rep_row_vector"
80+
; "positive_infinity"; "segment"; "subtract"; "sum"; "to_vector"
81+
; "to_row_vector"; "to_matrix"; "to_array_1d"; "to_array_2d"; "transpose"
82+
]
83+
|> String.Set.of_list in
84+
(* A simple check of linearity of an expression.
85+
allow_var is used for expressions like a*b, where at most one
86+
of a and b can be a variable
87+
*)
88+
let rec is_linear allow_var Expr.Fixed.{pattern; _} =
89+
match pattern with
90+
| Expr.Fixed.Pattern.Var _ -> allow_var
91+
| Lit _ -> true
92+
| Indexed (e, _) | Promotion (e, _, _) -> is_linear allow_var e
93+
| TernaryIf (e1, e2, e3) ->
94+
is_linear allow_var e1 && is_linear allow_var e2
95+
&& is_linear allow_var e3
96+
| FunApp (StanLib (name, _, _), args) ->
97+
is_linear_function allow_var name args
98+
| FunApp (CompilerInternal (FnMakeArray | FnMakeRowVec), args) ->
99+
List.for_all ~f:(is_linear allow_var) args
100+
| _ -> false
101+
and is_linear_function allow_var name (args : 'a Expr.Fixed.t list) =
102+
match (name, args) with
103+
| _, _ when Set.mem linear_fnames name ->
104+
List.for_all ~f:(is_linear allow_var) args
105+
| _, _ when List.for_all ~f:(is_linear false) args ->
106+
(* A function of all constants is fine *) true
107+
| ("Times__" | "Divide__" | "IntDivide__"), [a; b] ->
108+
(* We require at least one of these operands to be a constant *)
109+
(is_linear allow_var a && is_linear false b)
110+
|| (is_linear false a && is_linear allow_var b)
111+
| "fma", [a; b; c] ->
112+
(* Similar to above.
113+
Partial evaluation can create fmas where the user wrote Times *)
114+
is_linear allow_var c
115+
&& ( (is_linear allow_var a && is_linear false b)
116+
|| (is_linear false a && is_linear allow_var b) )
117+
| _ -> false in
118+
let maybe_nonlinear_tilde (stmt : Stmt.Located.t) =
119+
match stmt.pattern with
120+
(* a ~ foo(...) gets translated to target += foo_lpdf(a, ...) *)
121+
| Stmt.Fixed.Pattern.TargetPE
122+
{ pattern=
123+
Expr.Fixed.Pattern.FunApp
124+
((StanLib (_, FnLpdf _, _) | UserDefined (_, FnLpdf _)), e :: _)
125+
; _ }
126+
when not (is_linear true e) ->
127+
Set.Poly.singleton stmt.meta
128+
| _ -> Set.Poly.empty in
129+
let bad_tildes =
130+
fold_stmts
131+
~take_stmt:(fun m s -> Set.Poly.union m (maybe_nonlinear_tilde s))
132+
~take_expr:(fun m _ -> m)
133+
~init:Set.Poly.empty mir.log_prob in
134+
bad_tildes
135+
69136
(* Find all of the targets which are dependencies for a given label *)
70137
let var_deps info_map label ?expr:(expr_opt : Expr.Typed.t option = None)
71138
(targets : string Set.Poly.t) : string Set.Poly.t =
@@ -130,7 +197,7 @@ let list_arg_dependant_fundef_cf (mir : Program.Typed.t)
130197
Option.value_exn
131198
~message:
132199
"INTERNAL ERROR: Pedantic mode found CF dependent on an \
133-
arg,but the arg is mismatched. Please report a bug.\n"
200+
arg, but the arg is mismatched. Please report a bug.\n"
134201
(List.findi args ~f:(fun _ arg -> arg = name)) in
135202
(loc, ix, name) ) ) )
136203

@@ -320,6 +387,17 @@ let hard_constrained_warnings (mir : Program.Typed.t) =
320387
(Location_span.empty, nonsense_constrained_message pname) )
321388
pnames
322389

390+
let maybe_jacobian_adjustment_warnings (mir : Program.Typed.t) =
391+
let locations = list_possible_nonlinear mir in
392+
Set.Poly.map
393+
~f:(fun loc ->
394+
( loc
395+
, "Left-hand side of sampling statement (~) may contain a non-linear \
396+
transform of a parameter or local variable. If it does, you need to \
397+
include a target += statement with the log absolute determinant of \
398+
the Jacobian of the transform." ) )
399+
locations
400+
323401
let multi_tildes_message (vname : string) : string =
324402
Printf.sprintf
325403
"The parameter %s is on the left-hand side of more than one tilde \
@@ -421,9 +499,9 @@ let warn_pedantic (mir_unopt : Program.Typed.t) =
421499
let factor_graph = prog_factor_graph mir in
422500
Set.Poly.union_list
423501
[ uninitialized_warnings mir; unscaled_constants_warnings distributions_info
424-
; multi_tildes_warnings mir; hard_constrained_warnings mir
425-
; unused_params_warnings factor_graph mir; param_dependant_cf_warnings mir
426-
; param_dependant_fundef_cf_warnings mir
502+
; multi_tildes_warnings mir; maybe_jacobian_adjustment_warnings mir
503+
; hard_constrained_warnings mir; unused_params_warnings factor_graph mir
504+
; param_dependant_cf_warnings mir; param_dependant_fundef_cf_warnings mir
427505
; non_one_priors_warnings factor_graph mir
428506
; distribution_warnings distributions_info ]
429507
|> to_list
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
data {
2+
int N;
3+
real x;
4+
}
5+
parameters {
6+
vector[3] v;
7+
array[4,5] real a;
8+
cholesky_factor_corr[6] m;
9+
real y;
10+
real z;
11+
}
12+
model {
13+
// int_literal
14+
1 ~ normal(y,1);
15+
16+
// double_literal
17+
2.7 ~ normal(z,1);
18+
19+
// variable
20+
y ~ normal(0,1);
21+
22+
// fun
23+
m ~ lkj_corr_cholesky(2.0);
24+
(m + m) ~ lkj_corr_cholesky(2.0);
25+
(m - m) ~ lkj_corr_cholesky(2.0);
26+
(v + v) ~ normal(0,1);
27+
(v - v) ~ normal(0,1);
28+
block(m,1,1,1,1) ~ lkj_corr_cholesky(2.0);
29+
col(m,1) ~ normal(0,1);
30+
cols(m) ~ normal(0,1);
31+
row(m,1) ~ normal(0,1);
32+
rows(m) ~ normal(0,1);
33+
diagonal(m) ~ normal(0,1);
34+
head(v,2) ~ normal(0,1);
35+
negative_infinity() ~ normal(0,1);
36+
not_a_number() ~ normal(0,1);
37+
rep_matrix(1,3,3) ~ lkj_corr_cholesky(2.0);
38+
(v')' ~ normal(0,1);
39+
positive_infinity() ~ normal(0,1);
40+
segment(v,2,4) ~ normal(0,1);
41+
sum(v) ~ normal(0,1);
42+
tail(v,3) ~ normal(0,1);
43+
to_vector(m) ~ normal(0,1);
44+
45+
// index_op
46+
v[1] ~ normal(0,1);
47+
m[1] ~ normal(0,1);
48+
m[1,2] ~ normal(0,1);
49+
a[1,2] ~ normal(0,1);
50+
a[1][2] ~ normal(0,1);
51+
52+
// binary_op
53+
y + z ~ normal(0,1);
54+
y - z ~ normal(0,1);
55+
1 * z ~ normal(0,1);
56+
z * 1 ~ normal(0,1);
57+
1 / (1 / z) ~ normal(0,1);
58+
y + ((z / 2) * 3) ~ normal(0,1);
59+
2.0 * 3 ~ normal(y,1);
60+
61+
// unary_op
62+
(-y) ~ normal(0,1);
63+
-(-y) ~ normal(0,1);
64+
65+
// literals
66+
[1] ~ normal(0,1);
67+
[y + y] ~ normal(0,1);
68+
to_vector({1}) ~ normal(0,1);
69+
70+
}
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
parameters {
2+
real y;
3+
}
4+
model {
5+
log(y) ~ normal(0,1);
6+
// triggers the warning but isn't strictly a ~ statement
7+
target += normal_lpdf(log(y) | 0, 1);
8+
}

test/integration/good/warning/validate_jacobian_warning2.stan renamed to test/integration/cli-args/warn-pedantic/jacobian_warning2.stan

File renamed without changes.

test/integration/good/warning/validate_jacobian_warning3.stan renamed to test/integration/cli-args/warn-pedantic/jacobian_warning3.stan

File renamed without changes.

test/integration/good/warning/validate_jacobian_warning4.stan renamed to test/integration/cli-args/warn-pedantic/jacobian_warning4.stan

File renamed without changes.

test/integration/good/warning/validate_jacobian_warning5.stan renamed to test/integration/cli-args/warn-pedantic/jacobian_warning5.stan

File renamed without changes.

test/integration/good/warning/validate_jacobian_warning6.stan renamed to test/integration/cli-args/warn-pedantic/jacobian_warning6.stan

File renamed without changes.
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
parameters {
2+
real y;
3+
}
4+
model {
5+
[log(y)] ~ normal(0,1);
6+
log([y]) ~ normal(0,1);
7+
}

test/integration/good/warning/validate_jacobian_warning_user.stan renamed to test/integration/cli-args/warn-pedantic/jacobian_warning_user.stan

File renamed without changes.

0 commit comments

Comments
 (0)