@@ -12,11 +12,11 @@ module {
1212 func.func @hmc (%rng : tensor <2 xui64 >, %mean : tensor <f64 >, %stddev : tensor <f64 >) -> (tensor <10 x1 xf64 >, tensor <10 xi1 >, tensor <2 xui64 >) {
1313 %init_trace = arith.constant dense <[[0.0 ]]> : tensor <1 x1 xf64 >
1414 %step_size = arith.constant dense <0.1 > : tensor <f64 >
15- %res:3 = enzyme.mcmc @test (%rng , %mean , %stddev ) given %init_trace
15+ %res:8 = enzyme.mcmc @test (%rng , %mean , %stddev ) given %init_trace
1616 step_size = %step_size
1717 { hmc_config = #enzyme.hmc_config <trajectory_length = 1.0 >,
1818 name = " hmc" , selection = [[#enzyme.symbol <1 >]], all_addresses = [[#enzyme.symbol <1 >]], num_warmup = 0 , num_samples = 10 }
19- : (tensor <2 xui64 >, tensor <f64 >, tensor <f64 >, tensor <1 x1 xf64 >, tensor <f64 >) -> (tensor <10 x1 xf64 >, tensor <10 xi1 >, tensor <2 xui64 >)
19+ : (tensor <2 xui64 >, tensor <f64 >, tensor <f64 >, tensor <1 x1 xf64 >, tensor <f64 >) -> (tensor <10 x1 xf64 >, tensor <10 xi1 >, tensor <2 xui64 >, tensor < 1 x 1 x f64 >, tensor < 1 x 1 x f64 >, tensor < f64 >, tensor < f64 >, tensor < 1 x 1 x f64 > )
2020 return %res#0 , %res#1 , %res#2 : tensor <10 x1 xf64 >, tensor <10 xi1 >, tensor <2 xui64 >
2121 }
2222}
@@ -69,15 +69,19 @@ module {
6969// CHECK-NEXT: %[[RNG_M:.+]]:2 = enzyme.randomSplit %[[RNG_S]]#1 : (tensor<2xui64>) -> (tensor<2xui64>, tensor<2xui64>)
7070// CHECK-NEXT: %[[RNG_P:.+]], %[[P:.+]] = enzyme.random %[[RNG_M]]#0, %[[ZERO_F]], %[[ONE]] {rng_distribution = #enzyme<rng_distribution NORMAL>} : (tensor<2xui64>, tensor<f64>, tensor<f64>) -> (tensor<2xui64>, tensor<1x1xf64>)
7171//
72- // --- Initial kinetic energy K0 = 0.5 * p^T * p (contract over both dims for 2D) ---
73- // CHECK-NEXT: %[[KE0_DOT:.+]] = enzyme.dot %[[P]], %[[P]] {{{.*}}lhs_contracting_dimensions = array<i64: 0, 1>{{.*}}} : (tensor<1x1xf64>, tensor<1x1xf64>) -> tensor<f64>
72+ // --- Transform momentum by mass matrix sqrt: p_transformed = massMatrixSqrt @ p ---
73+ // CHECK-NEXT: %[[P_XFORM:.+]] = enzyme.dot %[[P]], {{.+}} {{{.*}}lhs_contracting_dimensions = array<i64: 1>{{.*}}} : (tensor<1x1xf64>, tensor<1x1xf64>) -> tensor<1x1xf64>
74+ //
75+ // --- Initial kinetic energy K0 = 0.5 * p_transformed^T * M^-1 * p_transformed ---
76+ // CHECK-NEXT: %[[P_V:.+]] = enzyme.dot %[[P_XFORM]], {{.+}} {{{.*}}} : (tensor<1x1xf64>, tensor<1x1xf64>) -> tensor<1x1xf64>
77+ // CHECK-NEXT: %[[KE0_DOT:.+]] = enzyme.dot %[[P_XFORM]], %[[P_V]] {{{.*}}lhs_contracting_dimensions = array<i64: 0, 1>{{.*}}} : (tensor<1x1xf64>, tensor<1x1xf64>) -> tensor<f64>
7478// CHECK-NEXT: %[[KE0:.+]] = arith.mulf %[[KE0_DOT]], %[[HALF]] : tensor<f64>
7579//
7680// --- Initial Hamiltonian H0 = U + K ---
7781// CHECK-NEXT: %[[H0:.+]] = arith.addf %[[U]], %[[KE0]] : tensor<f64>
7882//
7983// --- Leapfrog integration loop ---
80- // CHECK-NEXT: %[[LF:.+]]:5 = enzyme.for_loop(%[[C0]] : tensor<i64>) to(%[[C10]] : tensor<i64>) step(%[[C1]] : tensor<i64>) iter_args(%[[Q]], %[[P]], %[[GRAD]], %[[U]], %[[RNG_S]]#2 : tensor<1x1xf64>, tensor<1x1xf64>, tensor<1x1xf64>, tensor<f64>, tensor<2xui64>) -> tensor<1x1xf64>, tensor<1x1xf64>, tensor<1x1xf64>, tensor<f64>, tensor<2xui64> {
84+ // CHECK-NEXT: %[[LF:.+]]:5 = enzyme.for_loop(%[[C0]] : tensor<i64>) to(%[[C10]] : tensor<i64>) step(%[[C1]] : tensor<i64>) iter_args(%[[Q]], %[[P_XFORM]], %[[GRAD]], %[[U]], %[[RNG_S]]#2 : tensor<1x1xf64>, tensor<1x1xf64>, tensor<1x1xf64>, tensor<f64>, tensor<2xui64>) -> tensor<1x1xf64>, tensor<1x1xf64>, tensor<1x1xf64>, tensor<f64>, tensor<2xui64> {
8185// CHECK-NEXT: ^bb0(%[[LF_I:.+]]: tensor<i64>, %[[LF_Q:.+]]: tensor<1x1xf64>, %[[LF_P:.+]]: tensor<1x1xf64>, %[[LF_G:.+]]: tensor<1x1xf64>, %[[LF_U:.+]]: tensor<f64>, %[[LF_RNG:.+]]: tensor<2xui64>):
8286//
8387// --- Leapfrog: direction selection ---
@@ -91,7 +95,8 @@ module {
9195// CHECK-NEXT: %[[P_HALF:.+]] = arith.subf %[[LF_P]], %[[GRAD_SCALED]] : tensor<1x1xf64>
9296//
9397// --- Leapfrog: full step position q_new = q + eps * M^-1 * p_half ---
94- // CHECK-NEXT: %[[P_STEP:.+]] = arith.mulf %[[DIR_BC]], %[[P_HALF]] : tensor<1x1xf64>
98+ // CHECK-NEXT: %[[P_VINV:.+]] = enzyme.dot %[[P_HALF]], {{.+}} {{{.*}}} : (tensor<1x1xf64>, tensor<1x1xf64>) -> tensor<1x1xf64>
99+ // CHECK-NEXT: %[[P_STEP:.+]] = arith.mulf %[[DIR_BC]], %[[P_VINV]] : tensor<1x1xf64>
95100// CHECK-NEXT: %[[Q_NEW:.+]] = arith.addf %[[LF_Q]], %[[P_STEP]] : tensor<1x1xf64>
96101//
97102// --- Leapfrog: gradient at new position ---
@@ -112,8 +117,9 @@ module {
112117// CHECK-NEXT: enzyme.yield %[[Q_NEW]], %[[P_NEW]], %[[AD_LF]]#2, %[[AD_LF]]#0, %[[AD_LF]]#1 : tensor<1x1xf64>, tensor<1x1xf64>, tensor<1x1xf64>, tensor<f64>, tensor<2xui64>
113118// CHECK-NEXT: }
114119//
115- // --- Final kinetic energy K_new = 0.5 * p_new^T * p_new ---
116- // CHECK-NEXT: %[[KE_DOT:.+]] = enzyme.dot %[[LF]]#1, %[[LF]]#1 {{{.*}}lhs_contracting_dimensions = array<i64: 0, 1>{{.*}}} : (tensor<1x1xf64>, tensor<1x1xf64>) -> tensor<f64>
120+ // --- Final kinetic energy K_new = 0.5 * p_new^T * M^-1 * p_new ---
121+ // CHECK-NEXT: %[[LF_P_V:.+]] = enzyme.dot %[[LF]]#1, {{.+}} {{{.*}}} : (tensor<1x1xf64>, tensor<1x1xf64>) -> tensor<1x1xf64>
122+ // CHECK-NEXT: %[[KE_DOT:.+]] = enzyme.dot %[[LF]]#1, %[[LF_P_V]] {{{.*}}lhs_contracting_dimensions = array<i64: 0, 1>{{.*}}} : (tensor<1x1xf64>, tensor<1x1xf64>) -> tensor<f64>
117123// CHECK-NEXT: %[[KE:.+]] = arith.mulf %[[KE_DOT]], %[[HALF]] : tensor<f64>
118124//
119125// --- Final Hamiltonian H_new = U_new + K_new ---
0 commit comments