Skip to content

Commit 5f128a5

Browse files
authored
Merge pull request #1082 from stan-dev/stancjs_optimizations
Add optimization flags, debug-* flags to stancjs
2 parents 6c2bc5e + 02f8258 commit 5f128a5

7 files changed

Lines changed: 123 additions & 21 deletions

File tree

src/frontend/Pretty_printing.ml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -565,7 +565,7 @@ let pp_typed_expression ppf e =
565565
let pretty_print_program ?(bare_functions = false) ?(line_length = 78)
566566
?(inline_includes = false) p =
567567
let result =
568-
strf "%a" (pp_program ~bare_functions ~line_length ~inline_includes) p in
568+
str "%a" (pp_program ~bare_functions ~line_length ~inline_includes) p in
569569
check_correctness ~bare_functions p result ;
570570
result
571571

src/stanc/stanc.ml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -301,7 +301,7 @@ let use_file filename =
301301
if !dump_opt_mir_pretty then Program.Typed.pp Format.std_formatter opt ;
302302
opt in
303303
if !output_file = "" then output_file := remove_dotstan !model_file ^ ".hpp" ;
304-
let cpp = Fmt.strf "%a" Stan_math_code_gen.pp_prog opt_mir in
304+
let cpp = Fmt.str "%a" Stan_math_code_gen.pp_prog opt_mir in
305305
Out_channel.write_all !output_file ~data:cpp ;
306306
if !print_model_cpp then print_endline cpp )
307307

src/stancjs/stancjs.ml

Lines changed: 43 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,8 @@ let stan2cpp model_name model_string is_flag_set flag_val =
6666
flag_val "max-line-length"
6767
|> Option.map ~f:int_of_string
6868
|> Option.value ~default:78 in
69+
let mir = Ast_to_Mir.trans_prog model_name typed_ast in
70+
let tx_mir = Transform_Mir.trans_prog mir in
6971
if is_flag_set "auto-format" || is_flag_set "print-canonical" then
7072
r.return
7173
( Result.Ok
@@ -77,16 +79,47 @@ let stan2cpp model_name model_string is_flag_set flag_val =
7779
canonicalizer_settings ) )
7880
, warnings
7981
, [] ) ;
82+
if is_flag_set "debug-mir" then
83+
r.return
84+
( Result.Ok
85+
(Sexp.to_string_hum [%sexp (mir : Middle.Program.Typed.t)])
86+
, warnings
87+
, [] ) ;
88+
if is_flag_set "debug-mir-pretty" then
89+
r.return (Result.Ok (Fmt.str "%a" Program.Typed.pp mir), warnings, []) ;
8090
if is_flag_set "debug-generate-data" then
8191
r.return
8292
( Result.Ok (Debug_data_generation.print_data_prog typed_ast)
8393
, warnings
8494
, [] ) ;
85-
let mir = Ast_to_Mir.trans_prog model_name typed_ast in
86-
let tx_mir = Transform_Mir.trans_prog mir in
8795
let opt_mir =
88-
if is_flag_set "O" then Optimize.optimization_suite tx_mir else tx_mir
89-
in
96+
let opt_lvl =
97+
if is_flag_set "O0" then Optimize.O0
98+
else if is_flag_set "O1" then Optimize.O1
99+
else if is_flag_set "Oexperimental" || is_flag_set "O" then
100+
Optimize.Oexperimental
101+
else Optimize.O0 in
102+
Optimize.optimization_suite
103+
~settings:(Optimize.level_optimizations opt_lvl)
104+
tx_mir in
105+
if is_flag_set "debug-optimized-mir" then
106+
r.return
107+
( Result.Ok
108+
(Sexp.to_string_hum [%sexp (opt_mir : Middle.Program.Typed.t)])
109+
, warnings
110+
, [] ) ;
111+
if is_flag_set "debug-optimized-mir-pretty" then
112+
r.return
113+
(Result.Ok (Fmt.str "%a" Program.Typed.pp opt_mir), warnings, []) ;
114+
if is_flag_set "debug-transformed-mir" then
115+
r.return
116+
( Result.Ok
117+
(Sexp.to_string_hum [%sexp (tx_mir : Middle.Program.Typed.t)])
118+
, warnings
119+
, [] ) ;
120+
if is_flag_set "debug-transformed-mir-pretty" then
121+
r.return
122+
(Result.Ok (Fmt.str "%a" Program.Typed.pp tx_mir), warnings, []) ;
90123
let cpp = Fmt.str "%a" Stan_math_code_gen.pp_prog opt_mir in
91124
let uninit_warnings =
92125
if is_flag_set "warn-uninitialized" then
@@ -141,4 +174,9 @@ let stan2cpp_wrapped name code (flags : Js.string_array Js.t Js.opt) =
141174
(warnings @ pedantic_mode_warnings) in
142175
wrap_result ?printed_filename result ~warnings
143176

144-
let () = Js.export "stanc" stan2cpp_wrapped
177+
let dump_stan_math_signatures () =
178+
Js.string @@ Fmt.str "%a" Stan_math_signatures.pretty_print_all_math_sigs ()
179+
180+
let () =
181+
Js.export "dump_stan_math_signatures" dump_stan_math_signatures ;
182+
Js.export "stanc" stan2cpp_wrapped

test/stancjs/debug.js

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
var stanc = require('../../src/stancjs/stancjs.bc.js');
2+
var utils = require("./utils/utils.js");
3+
4+
let basic_model = `
5+
parameters {
6+
real y;
7+
}
8+
model {
9+
y~std_normal();
10+
}
11+
`
12+
13+
let debug_mir_test = stanc.stanc("basic", basic_model, ["debug-mir"]);
14+
var ind = debug_mir_test.result.search("#include \\<stan/model/");
15+
console.assert(ind < 0, "ERROR: MIR printing is not valid.")
16+
17+
18+
let debug_mir_pretty_test = stanc.stanc("basic", basic_model, ["debug-mir-pretty"]);
19+
var ind = debug_mir_pretty_test.result.search("#include \\<stan/model/");
20+
console.assert(ind < 0, "ERROR: MIR pretty printing is not valid.")
21+
22+
23+
let debug_opt_mir_test = stanc.stanc("basic", basic_model, ["01", "debug-optimized-mir"]);
24+
var ind = debug_opt_mir_test.result.search("#include \\<stan/model/");
25+
console.assert(ind < 0, "ERROR: Optimized MIR printing is not valid.")
26+
27+
28+
let debug_opt_mir_pretty_test = stanc.stanc("basic", basic_model, ["01", "debug-optimized-mir-pretty"]);
29+
var ind = debug_opt_mir_pretty_test.result.search("#include \\<stan/model/");
30+
console.assert(ind < 0, "ERROR: Optimized MIR pretty printing is not valid.")
31+
32+
let debug_tx_mir_test = stanc.stanc("basic", basic_model, ["debug-transformed-mir"]);
33+
var ind = debug_tx_mir_test.result.search("#include \\<stan/model/");
34+
console.assert(ind < 0, "ERROR: Transformed MIR printing is not valid.")
35+
36+
37+
let debug_tx_mir_pretty_test = stanc.stanc("basic", basic_model, ["debug-transformed-mir-pretty"]);
38+
var ind = debug_tx_mir_pretty_test.result.search("#include \\<stan/model/");
39+
console.assert(ind < 0, "ERROR: Transformed MIR pretty printing is not valid.")
40+

test/stancjs/math_sigs.js

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
var stanc = require('../../src/stancjs/stancjs.bc.js');
2+
var utils = require("./utils/utils.js");
3+
4+
let stan_math_sigs = stanc.dump_stan_math_signatures();
5+
console.assert(stan_math_sigs.includes("bernoulli_cdf(int, real) => real"), "Failed to find bernoulli signature!")
6+
console.assert(stan_math_sigs.includes("zeros_array(int) => array[] real"), "Failed to find zeros_array signature!")

test/stancjs/optimization.js

Lines changed: 28 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,32 @@ transformed data {
1111
`
1212
var opt_test = stanc.stanc("optimization", opt_model, []);
1313
var ind = opt_test.result.search("int t = 1; t <= 5; \\+\\+t");
14-
if (ind == -1) {
15-
console.log("ERROR: Optimization without the O flag!")
16-
}
14+
console.assert(ind > -1, "ERROR: Optimization without the O flag!")
1715

1816
var opt_test = stanc.stanc("optimization", opt_model, ["O"]);
1917
var ind = opt_test.result.search("int t = 1; t <= 5; \\+\\+t");
20-
if (ind > -1) {
21-
console.log("ERROR: No optimization without the O flag!")
18+
console.assert(ind < 0, "ERROR: No optimization without the O flag!")
19+
20+
var ad_model = `
21+
data {
22+
matrix[10, 10] X_d;
23+
}
24+
parameters {
25+
matrix[10, 10] X_p;
26+
}
27+
28+
transformed parameters {
29+
matrix[10, 10] X_tp1 = X_d;
2230
}
31+
`
32+
33+
var opt_test = stanc.stanc("optimization", ad_model, ["O1"]);
34+
var ind = opt_test.result.search("\\<double, -1, -1\\> X_tp1");
35+
console.assert(ind > -1, "ERROR: No AD optimization with the O1 flag!")
36+
37+
var opt_test = stanc.stanc("optimization", ad_model, ["O0"]);
38+
var ind = opt_test.result.search("\\<local_scalar_t__, -1, -1\\> X_tp1");
39+
console.assert(ind > -1, "ERROR: AD optimization without the O1 flag!")
2340

2441
var glm_model = `
2542
data {
@@ -42,16 +59,14 @@ data {
4259
var no_opencl_test = stanc.stanc("no_opencl", glm_model);
4360
utils.print_error(no_opencl_test)
4461
var ind = no_opencl_test.result.search("matrix_cl<int> y_opencl__");
45-
if (ind > -1) {
46-
console.log("ERROR: OpenCL code found without the use-opencl flag!")
47-
}
62+
console.assert(ind < 0, "ERROR: OpenCL code found without the use-opencl flag!")
63+
4864

4965
var opencl_test = stanc.stanc("opencl", glm_model, ["use-opencl"]);
5066
utils.print_error(opencl_test)
5167
var ind = opencl_test.result.search("matrix_cl<int> y_opencl__");
52-
if (ind == -1) {
53-
console.log("ERROR: No OpenCL code found with the use-opencl flag!")
54-
}
68+
console.assert(ind > -1, "ERROR: No OpenCL code found with the use-opencl flag!")
69+
5570

5671
// multiple flags
5772

@@ -86,6 +101,5 @@ utils.print_error(opencl_test)
86101
var opencl_test = stanc.stanc("opencl", glm_model2, ["use-opencl", "allow-undefined"]);
87102
utils.print_error(opencl_test)
88103
var ind = opencl_test.result.search("matrix_cl<int> y_opencl__");
89-
if (ind == -1) {
90-
console.log("ERROR: No OpenCL code found with the use-opencl flag!")
91-
}
104+
console.assert(ind > -1, "ERROR: No OpenCL code found with the use-opencl flag!")
105+

test/stancjs/stancjs.expected

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,8 @@ dim(z) = (3, 4)
6767
dim(w) = (3)
6868
dim(p) = (4, 3)
6969

70+
$ node debug.js
71+
7072
$ node filename.js
7173
Semantic error in 'good_filename', line 6, column 4 to column 5:
7274
Identifier 'z' not in scope.
@@ -179,6 +181,8 @@ $ node info.js
179181
"distributions": [ ],
180182
"included_files": [ ] }
181183

184+
$ node math_sigs.js
185+
182186
$ node optimization.js
183187
Semantic error in 'string', line 3, column 8 to column 11:
184188
Function is declared without specifying a definition.

0 commit comments

Comments
 (0)