Skip to content

Commit 62b3124

Browse files
authored
mcmc state (#2727)
1 parent b0fa681 commit 62b3124

9 files changed

Lines changed: 105 additions & 56 deletions

File tree

enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -692,10 +692,16 @@ def MCMCOp : Enzyme_Op<"mcmc", [DeclareOpInterfaceMethods<SymbolUserOpInterface>
692692
The `selection` attribute determines which addresses to sample via HMC/NUTS.
693693
All sample addresses are included in the trace tensor for consistency.
694694

695-
Returns: (trace, diagnostics, rng)
695+
Returns: (trace, diagnostics, rng, final_position, final_gradient,
696+
final_potential_energy, final_step_size, final_inverse_mass_matrix)
696697
- trace: tensor<num_samples x position_size x f64>
697698
- diagnostics: tensor<num_samples x i1> - placeholder for future expansion
698699
- rng: updated RNG state
700+
- final_position: tensor<1 x position_size x f64> - position after last sample
701+
- final_gradient: tensor<1 x position_size x f64> - gradient at final position
702+
- final_potential_energy: tensor<f64> - potential energy at final position
703+
- final_step_size: tensor<f64> - adapted step size (after warmup)
704+
- final_inverse_mass_matrix: tensor - adapted inverse mass matrix (after warmup); falls back to an all-ones tensor of the position type when no inverse mass matrix is available
699705
}];
700706

701707
let arguments = (ins
@@ -721,13 +727,21 @@ def MCMCOp : Enzyme_Op<"mcmc", [DeclareOpInterfaceMethods<SymbolUserOpInterface>
721727
OptionalAttr<FlatSymbolRefAttr>:$logpdf_fn,
722728
Optional<AnyRankedTensor>:$initial_position,
723729

730+
Optional<AnyRankedTensor>:$initial_gradient,
731+
Optional<AnyRankedTensor>:$initial_potential_energy,
732+
724733
DefaultValuedStrAttr<StrAttr, "">:$name
725734
);
726735

727736
let results = (outs
728737
AnyRankedTensor:$trace,
729738
AnyRankedTensor:$diagnostics,
730-
AnyType:$output_rng_state
739+
AnyType:$output_rng_state,
740+
AnyRankedTensor:$final_position,
741+
AnyRankedTensor:$final_gradient,
742+
AnyRankedTensor:$final_potential_energy,
743+
AnyRankedTensor:$final_step_size,
744+
AnyRankedTensor:$final_inverse_mass_matrix
731745
);
732746

733747
let assemblyFormat = [{
@@ -738,6 +752,8 @@ def MCMCOp : Enzyme_Op<"mcmc", [DeclareOpInterfaceMethods<SymbolUserOpInterface>
738752
(`step_size` `=` $step_size^)?
739753
(`logpdf_fn` `=` $logpdf_fn^)?
740754
(`initial_position` `=` $initial_position^)?
755+
(`initial_gradient` `=` $initial_gradient^)?
756+
(`initial_potential_energy` `=` $initial_potential_energy^)?
741757
attr-dict `:` functional-type(operands, results)
742758
}];
743759

enzyme/Enzyme/MLIR/Passes/ProbProgMLIRPass.cpp

Lines changed: 39 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -763,11 +763,27 @@ struct ProbProgPass : public enzyme::impl::ProbProgPassBase<ProbProgPass> {
763763
}
764764
};
765765

766-
auto baseCtx =
767-
makeHMCContext(adaptedInvMass, adaptedMassMatrixSqrt, stepSize);
768-
auto initState = InitHMC(
769-
rewriter, loc, rngInput, baseCtx,
770-
hasLogpdfFn ? mcmcOp.getInitialPosition() : Value(), debugDump);
766+
Value currentQ, currentGrad, currentU, currentRng;
767+
768+
auto initialGrad = mcmcOp.getInitialGradient();
769+
auto initialPE = mcmcOp.getInitialPotentialEnergy();
770+
771+
if (hasLogpdfFn && initialGrad && initialPE) {
772+
currentQ = mcmcOp.getInitialPosition();
773+
currentGrad = initialGrad;
774+
currentU = initialPE;
775+
currentRng = rngInput;
776+
} else {
777+
auto baseCtx =
778+
makeHMCContext(adaptedInvMass, adaptedMassMatrixSqrt, stepSize);
779+
auto initState = InitHMC(
780+
rewriter, loc, rngInput, baseCtx,
781+
hasLogpdfFn ? mcmcOp.getInitialPosition() : Value(), debugDump);
782+
currentQ = initState.q0;
783+
currentGrad = initState.grad0;
784+
currentU = initState.U0;
785+
currentRng = initState.rng;
786+
}
771787

772788
auto runSampleStepWithStepSize =
773789
[&](OpBuilder &builder, Location loc, Value q, Value grad, Value U,
@@ -783,10 +799,6 @@ struct ProbProgPass : public enzyme::impl::ProbProgPassBase<ProbProgPass> {
783799
}
784800
};
785801

786-
Value currentQ = initState.q0;
787-
Value currentGrad = initState.grad0;
788-
Value currentU = initState.U0;
789-
Value currentRng = initState.rng;
790802
Value adaptedStepSize = stepSize;
791803

792804
auto runSampleStepWithInvMass =
@@ -804,6 +816,17 @@ struct ProbProgPass : public enzyme::impl::ProbProgPassBase<ProbProgPass> {
804816
}
805817
};
806818

819+
if (!adaptedInvMass) {
820+
adaptedInvMass = arith::ConstantOp::create(
821+
rewriter, loc, positionType,
822+
DenseElementsAttr::get(positionType,
823+
rewriter.getFloatAttr(elemType, 1.0)));
824+
adaptedMassMatrixSqrt = arith::ConstantOp::create(
825+
rewriter, loc, positionType,
826+
DenseElementsAttr::get(positionType,
827+
rewriter.getFloatAttr(elemType, 1.0)));
828+
}
829+
807830
if (numWarmup > 0) {
808831
auto c0 = arith::ConstantOp::create(
809832
rewriter, loc, i64TensorType,
@@ -1280,16 +1303,20 @@ struct ProbProgPass : public enzyme::impl::ProbProgPassBase<ProbProgPass> {
12801303
selectedAcceptedBuffer});
12811304

12821305
rewriter.setInsertionPointAfter(forLoopOp);
1306+
Value finalQ = forLoopOp.getResult(0);
1307+
Value finalGrad = forLoopOp.getResult(1);
1308+
Value finalU = forLoopOp.getResult(2);
1309+
Value finalRng = forLoopOp.getResult(3);
12831310
Value finalSamplesBuffer = forLoopOp.getResult(4);
12841311
Value finalAcceptedBuffer = forLoopOp.getResult(5);
1285-
Value finalRng = forLoopOp.getResult(3);
12861312

12871313
finalSamplesBuffer =
12881314
conditionalDump(rewriter, loc, finalSamplesBuffer,
12891315
"MCMC: collected samples", debugDump);
12901316

1291-
rewriter.replaceOp(mcmcOp,
1292-
{finalSamplesBuffer, finalAcceptedBuffer, finalRng});
1317+
rewriter.replaceOp(mcmcOp, {finalSamplesBuffer, finalAcceptedBuffer,
1318+
finalRng, finalQ, finalGrad, finalU,
1319+
adaptedStepSize, adaptedInvMass});
12931320

12941321
return success();
12951322
}

enzyme/test/MLIR/ProbProg/exp_transform.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,11 @@ module {
1717
%inverse_mass_matrix = arith.constant dense<[[1.0, 0.0], [0.0, 1.0]]> : tensor<2x2xf64>
1818
%step_size = arith.constant dense<0.1> : tensor<f64>
1919

20-
%res:3 = enzyme.mcmc @test(%rng, %rate) given %init_trace
20+
%res:8 = enzyme.mcmc @test(%rng, %rate) given %init_trace
2121
inverse_mass_matrix = %inverse_mass_matrix
2222
step_size = %step_size
2323
{ hmc_config = #enzyme.hmc_config<trajectory_length = 1.000000e+00 : f64>, name = "hmc", selection = [[#enzyme.symbol<1>], [#enzyme.symbol<2>]], all_addresses = [[#enzyme.symbol<1>], [#enzyme.symbol<2>]], num_warmup = 0, num_samples = 1 }
24-
: (tensor<2xui64>, tensor<f64>, tensor<1x2xf64>, tensor<2x2xf64>, tensor<f64>) -> (tensor<1x2xf64>, tensor<1xi1>, tensor<2xui64>)
24+
: (tensor<2xui64>, tensor<f64>, tensor<1x2xf64>, tensor<2x2xf64>, tensor<f64>) -> (tensor<1x2xf64>, tensor<1xi1>, tensor<2xui64>, tensor<1x2xf64>, tensor<1x2xf64>, tensor<f64>, tensor<f64>, tensor<1x2xf64>)
2525
return %res#0, %res#1, %res#2 : tensor<1x2xf64>, tensor<1xi1>, tensor<2xui64>
2626
}
2727
}

enzyme/test/MLIR/ProbProg/hmc_diag_mass.mlir

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,25 +14,25 @@ module {
1414
%init_trace = arith.constant dense<[[0.0, 0.0]]> : tensor<1x2xf64>
1515
%inv_mass = arith.constant dense<[2.0, 3.0]> : tensor<2xf64>
1616
%step_size = arith.constant dense<0.1> : tensor<f64>
17-
%res:3 = enzyme.mcmc @test(%rng, %mean, %stddev) given %init_trace
17+
%res:8 = enzyme.mcmc @test(%rng, %mean, %stddev) given %init_trace
1818
inverse_mass_matrix = %inv_mass
1919
step_size = %step_size
2020
{ hmc_config = #enzyme.hmc_config<trajectory_length = 1.0>,
2121
name = "hmc_diag", selection = [[#enzyme.symbol<1>], [#enzyme.symbol<2>]], all_addresses = [[#enzyme.symbol<1>], [#enzyme.symbol<2>]], num_warmup = 0, num_samples = 1 }
22-
: (tensor<2xui64>, tensor<f64>, tensor<f64>, tensor<1x2xf64>, tensor<2xf64>, tensor<f64>) -> (tensor<1x2xf64>, tensor<1xi1>, tensor<2xui64>)
22+
: (tensor<2xui64>, tensor<f64>, tensor<f64>, tensor<1x2xf64>, tensor<2xf64>, tensor<f64>) -> (tensor<1x2xf64>, tensor<1xi1>, tensor<2xui64>, tensor<1x2xf64>, tensor<1x2xf64>, tensor<f64>, tensor<f64>, tensor<1x2xf64>)
2323
return %res#0, %res#1, %res#2 : tensor<1x2xf64>, tensor<1xi1>, tensor<2xui64>
2424
}
2525

2626
func.func @nuts_diag_mass(%rng : tensor<2xui64>, %mean : tensor<f64>, %stddev : tensor<f64>) -> (tensor<1x2xf64>, tensor<1xi1>, tensor<2xui64>) {
2727
%init_trace = arith.constant dense<[[0.0, 0.0]]> : tensor<1x2xf64>
2828
%inv_mass = arith.constant dense<[2.0, 3.0]> : tensor<2xf64>
2929
%step_size = arith.constant dense<0.1> : tensor<f64>
30-
%res:3 = enzyme.mcmc @test(%rng, %mean, %stddev) given %init_trace
30+
%res:8 = enzyme.mcmc @test(%rng, %mean, %stddev) given %init_trace
3131
inverse_mass_matrix = %inv_mass
3232
step_size = %step_size
3333
{ nuts_config = #enzyme.nuts_config<max_tree_depth = 3, max_delta_energy = 1000.0, adapt_step_size = false, adapt_mass_matrix = false>,
3434
name = "nuts_diag", selection = [[#enzyme.symbol<1>], [#enzyme.symbol<2>]], all_addresses = [[#enzyme.symbol<1>], [#enzyme.symbol<2>]], num_warmup = 0, num_samples = 1 }
35-
: (tensor<2xui64>, tensor<f64>, tensor<f64>, tensor<1x2xf64>, tensor<2xf64>, tensor<f64>) -> (tensor<1x2xf64>, tensor<1xi1>, tensor<2xui64>)
35+
: (tensor<2xui64>, tensor<f64>, tensor<f64>, tensor<1x2xf64>, tensor<2xf64>, tensor<f64>) -> (tensor<1x2xf64>, tensor<1xi1>, tensor<2xui64>, tensor<1x2xf64>, tensor<1x2xf64>, tensor<f64>, tensor<f64>, tensor<1x2xf64>)
3636
return %res#0, %res#1, %res#2 : tensor<1x2xf64>, tensor<1xi1>, tensor<2xui64>
3737
}
3838
}

enzyme/test/MLIR/ProbProg/hmc_kernel.mlir

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,11 @@ module {
1212
func.func @hmc(%rng : tensor<2xui64>, %mean : tensor<f64>, %stddev : tensor<f64>) -> (tensor<10x1xf64>, tensor<10xi1>, tensor<2xui64>) {
1313
%init_trace = arith.constant dense<[[0.0]]> : tensor<1x1xf64>
1414
%step_size = arith.constant dense<0.1> : tensor<f64>
15-
%res:3 = enzyme.mcmc @test(%rng, %mean, %stddev) given %init_trace
15+
%res:8 = enzyme.mcmc @test(%rng, %mean, %stddev) given %init_trace
1616
step_size = %step_size
1717
{ hmc_config = #enzyme.hmc_config<trajectory_length = 1.0>,
1818
name = "hmc", selection = [[#enzyme.symbol<1>]], all_addresses = [[#enzyme.symbol<1>]], num_warmup = 0, num_samples = 10 }
19-
: (tensor<2xui64>, tensor<f64>, tensor<f64>, tensor<1x1xf64>, tensor<f64>) -> (tensor<10x1xf64>, tensor<10xi1>, tensor<2xui64>)
19+
: (tensor<2xui64>, tensor<f64>, tensor<f64>, tensor<1x1xf64>, tensor<f64>) -> (tensor<10x1xf64>, tensor<10xi1>, tensor<2xui64>, tensor<1x1xf64>, tensor<1x1xf64>, tensor<f64>, tensor<f64>, tensor<1x1xf64>)
2020
return %res#0, %res#1, %res#2 : tensor<10x1xf64>, tensor<10xi1>, tensor<2xui64>
2121
}
2222
}
@@ -69,15 +69,19 @@ module {
6969
// CHECK-NEXT: %[[RNG_M:.+]]:2 = enzyme.randomSplit %[[RNG_S]]#1 : (tensor<2xui64>) -> (tensor<2xui64>, tensor<2xui64>)
7070
// CHECK-NEXT: %[[RNG_P:.+]], %[[P:.+]] = enzyme.random %[[RNG_M]]#0, %[[ZERO_F]], %[[ONE]] {rng_distribution = #enzyme<rng_distribution NORMAL>} : (tensor<2xui64>, tensor<f64>, tensor<f64>) -> (tensor<2xui64>, tensor<1x1xf64>)
7171
//
72-
// --- Initial kinetic energy K0 = 0.5 * p^T * p (contract over both dims for 2D) ---
73-
// CHECK-NEXT: %[[KE0_DOT:.+]] = enzyme.dot %[[P]], %[[P]] {{{.*}}lhs_contracting_dimensions = array<i64: 0, 1>{{.*}}} : (tensor<1x1xf64>, tensor<1x1xf64>) -> tensor<f64>
72+
// --- Transform momentum by mass matrix sqrt: p_transformed = massMatrixSqrt @ p ---
73+
// CHECK-NEXT: %[[P_XFORM:.+]] = enzyme.dot %[[P]], {{.+}} {{{.*}}lhs_contracting_dimensions = array<i64: 1>{{.*}}} : (tensor<1x1xf64>, tensor<1x1xf64>) -> tensor<1x1xf64>
74+
//
75+
// --- Initial kinetic energy K0 = 0.5 * p_transformed^T * M^-1 * p_transformed ---
76+
// CHECK-NEXT: %[[P_V:.+]] = enzyme.dot %[[P_XFORM]], {{.+}} {{{.*}}} : (tensor<1x1xf64>, tensor<1x1xf64>) -> tensor<1x1xf64>
77+
// CHECK-NEXT: %[[KE0_DOT:.+]] = enzyme.dot %[[P_XFORM]], %[[P_V]] {{{.*}}lhs_contracting_dimensions = array<i64: 0, 1>{{.*}}} : (tensor<1x1xf64>, tensor<1x1xf64>) -> tensor<f64>
7478
// CHECK-NEXT: %[[KE0:.+]] = arith.mulf %[[KE0_DOT]], %[[HALF]] : tensor<f64>
7579
//
7680
// --- Initial Hamiltonian H0 = U + K ---
7781
// CHECK-NEXT: %[[H0:.+]] = arith.addf %[[U]], %[[KE0]] : tensor<f64>
7882
//
7983
// --- Leapfrog integration loop ---
80-
// CHECK-NEXT: %[[LF:.+]]:5 = enzyme.for_loop(%[[C0]] : tensor<i64>) to(%[[C10]] : tensor<i64>) step(%[[C1]] : tensor<i64>) iter_args(%[[Q]], %[[P]], %[[GRAD]], %[[U]], %[[RNG_S]]#2 : tensor<1x1xf64>, tensor<1x1xf64>, tensor<1x1xf64>, tensor<f64>, tensor<2xui64>) -> tensor<1x1xf64>, tensor<1x1xf64>, tensor<1x1xf64>, tensor<f64>, tensor<2xui64> {
84+
// CHECK-NEXT: %[[LF:.+]]:5 = enzyme.for_loop(%[[C0]] : tensor<i64>) to(%[[C10]] : tensor<i64>) step(%[[C1]] : tensor<i64>) iter_args(%[[Q]], %[[P_XFORM]], %[[GRAD]], %[[U]], %[[RNG_S]]#2 : tensor<1x1xf64>, tensor<1x1xf64>, tensor<1x1xf64>, tensor<f64>, tensor<2xui64>) -> tensor<1x1xf64>, tensor<1x1xf64>, tensor<1x1xf64>, tensor<f64>, tensor<2xui64> {
8185
// CHECK-NEXT: ^bb0(%[[LF_I:.+]]: tensor<i64>, %[[LF_Q:.+]]: tensor<1x1xf64>, %[[LF_P:.+]]: tensor<1x1xf64>, %[[LF_G:.+]]: tensor<1x1xf64>, %[[LF_U:.+]]: tensor<f64>, %[[LF_RNG:.+]]: tensor<2xui64>):
8286
//
8387
// --- Leapfrog: direction selection ---
@@ -91,7 +95,8 @@ module {
9195
// CHECK-NEXT: %[[P_HALF:.+]] = arith.subf %[[LF_P]], %[[GRAD_SCALED]] : tensor<1x1xf64>
9296
//
9397
// --- Leapfrog: full step position q_new = q + eps * M^-1 * p_half ---
94-
// CHECK-NEXT: %[[P_STEP:.+]] = arith.mulf %[[DIR_BC]], %[[P_HALF]] : tensor<1x1xf64>
98+
// CHECK-NEXT: %[[P_VINV:.+]] = enzyme.dot %[[P_HALF]], {{.+}} {{{.*}}} : (tensor<1x1xf64>, tensor<1x1xf64>) -> tensor<1x1xf64>
99+
// CHECK-NEXT: %[[P_STEP:.+]] = arith.mulf %[[DIR_BC]], %[[P_VINV]] : tensor<1x1xf64>
95100
// CHECK-NEXT: %[[Q_NEW:.+]] = arith.addf %[[LF_Q]], %[[P_STEP]] : tensor<1x1xf64>
96101
//
97102
// --- Leapfrog: gradient at new position ---
@@ -112,8 +117,9 @@ module {
112117
// CHECK-NEXT: enzyme.yield %[[Q_NEW]], %[[P_NEW]], %[[AD_LF]]#2, %[[AD_LF]]#0, %[[AD_LF]]#1 : tensor<1x1xf64>, tensor<1x1xf64>, tensor<1x1xf64>, tensor<f64>, tensor<2xui64>
113118
// CHECK-NEXT: }
114119
//
115-
// --- Final kinetic energy K_new = 0.5 * p_new^T * p_new ---
116-
// CHECK-NEXT: %[[KE_DOT:.+]] = enzyme.dot %[[LF]]#1, %[[LF]]#1 {{{.*}}lhs_contracting_dimensions = array<i64: 0, 1>{{.*}}} : (tensor<1x1xf64>, tensor<1x1xf64>) -> tensor<f64>
120+
// --- Final kinetic energy K_new = 0.5 * p_new^T * M^-1 * p_new ---
121+
// CHECK-NEXT: %[[LF_P_V:.+]] = enzyme.dot %[[LF]]#1, {{.+}} {{{.*}}} : (tensor<1x1xf64>, tensor<1x1xf64>) -> tensor<1x1xf64>
122+
// CHECK-NEXT: %[[KE_DOT:.+]] = enzyme.dot %[[LF]]#1, %[[LF_P_V]] {{{.*}}lhs_contracting_dimensions = array<i64: 0, 1>{{.*}}} : (tensor<1x1xf64>, tensor<1x1xf64>) -> tensor<f64>
117123
// CHECK-NEXT: %[[KE:.+]] = arith.mulf %[[KE_DOT]], %[[HALF]] : tensor<f64>
118124
//
119125
// --- Final Hamiltonian H_new = U_new + K_new ---

0 commit comments

Comments
 (0)