11#include " MagneX.H"
22#include " Demagnetization.H"
3- #include < torch/script.h>
43#include < AMReX_MultiFab.H>
54#include < AMReX_VisMF.H>
65#include < AMReX_ParmParse.H>
1110
1211#include < cmath>
1312
13+ #ifdef AMREX_USE_ML
14+ #include < torch/script.h>
1415#include < ATen/cuda/CUDAContext.h> // for at::cuda::setDevice
1516#include < c10/cuda/CUDAGuard.h>
17+ #endif
18+
1619using namespace amrex ;
1720using namespace MagneX ;
1821
@@ -60,10 +63,11 @@ void main_main ()
6063 Array<MultiFab, AMREX_SPACEDIM> LLG_RHS;
6164 Array<MultiFab, AMREX_SPACEDIM> LLG_RHS_pre;
6265 Array<MultiFab, AMREX_SPACEDIM> LLG_RHS_avg;
66+ #ifdef AMREX_USE_ML
6367 torch::jit::script::Module ml_module;
6468 torch::jit::script::Module x_norm_module;
6569 torch::jit::script::Module y_norm_module;
66-
70+ # endif
6771
6872 // Declare variables for hysteresis
6973 Real normalized_Mx;
@@ -114,6 +118,7 @@ void main_main ()
114118 // **********************************
115119 // // LOAD PYTORCH MODEL
116120 if (ml_enable == 1 ) {
121+ #ifdef AMREX_USE_ML
117122 BL_PROFILE_VAR (" LoadPytorch" ,LoadPytorch);
118123
119124 // Load pytorch module via torch script
@@ -164,6 +169,7 @@ void main_main ()
164169 << expected_spatial[2 ] << " \n " ;
165170
166171 Print () << " Model loaded.\n " ;
172+ #endif
167173 }
168174 else {
169175 Print () << " ML disabled. Skipping model load.\n " ;
@@ -449,7 +455,7 @@ void main_main ()
449455 if (demag_coupling == 1 ) {
450456 // demag_solver.CalculateH_demag(Mfield_old, H_demagfield);
451457 if (ml_enable == 1 ) {
452- // CalculateH_demag_ML(Mfield_old, x_norm_module, ml_module, y_norm_module, H_demagfield);
458+ # ifdef AMREX_USE_ML
453459 for (amrex::MFIter mfi (Mfield_old[0 ], amrex::TilingIfNotGPU ());
454460 mfi.isValid (); ++mfi)
455461 {
@@ -460,14 +466,14 @@ void main_main ()
460466 Mfield_old, mfi, bx, expected_spatial, device_id
461467 );
462468
463- //
464469 at::Tensor norm = NormalizeInput (M_cuda_f32, x_norm_module);
465470 at::Tensor pred = MLForwardOnly (norm, ml_module);
466471 at::Tensor denorm_f64 = DenormalizeOutput (pred, y_norm_module);
467472
468473 // unpack: tensor -> MultiFab
469474 UnpackTensorToHfieldDynamic (denorm_f64, H_demagfield, mfi, bx, expected_spatial);
470475 }
476+ #endif
471477 } else {
472478 // demag_solver.CalculateH_demag(Mfield_old, H_demagfield);
473479 amrex::Gpu::streamSynchronize ();
@@ -599,8 +605,7 @@ void main_main ()
599605 if (demag_coupling == 1 ) {
600606 // demag_solver.CalculateH_demag(Mfield, H_demagfield);
601607 if (ml_enable == 1 ) {
602- // CalculateH_demag_ML(Mfield, x_norm_module, ml_module, y_norm_module, H_demagfield);
603-
608+ #ifdef AMREX_USE_ML
604609 for (amrex::MFIter mfi (Mfield_old[0 ], amrex::TilingIfNotGPU ());
605610 mfi.isValid (); ++mfi)
606611 {
@@ -618,6 +623,7 @@ void main_main ()
618623 // unpack: tensor -> MultiFab
619624 UnpackTensorToHfieldDynamic (denorm_f64, H_demagfield, mfi, bx, expected_spatial);
620625 }
626+ #endif
621627 } else {
622628 // demag_solver.CalculateH_demag(Mfield, H_demagfield);
623629 amrex::Gpu::streamSynchronize ();
@@ -810,7 +816,7 @@ void main_main ()
810816 if (fast_demag==1 ) {
811817 // demag_solver.CalculateH_demag(ar_state, H_demagfield);
812818 if (ml_enable == 1 ) {
813- CalculateH_demag_ML (ar_state, x_norm_module, ml_module, y_norm_module, H_demagfield );
819+ amrex::Abort ( " add ML demag to fast dynamics " );
814820 } else {
815821 demag_solver.CalculateH_demag (ar_state, H_demagfield);
816822 }
0 commit comments