Skip to content

Commit 3a87d3d

Browse files
committed
non ml compile
1 parent 9832eb8 commit 3a87d3d

6 files changed

Lines changed: 32 additions & 19 deletions

File tree

Exec/GNUmakefile

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# AMREX_HOME defines the directory in which we will find all the AMReX code.
2-
AMREX_HOME = ../../amrex
2+
AMREX_HOME ?= ../../amrex
3+
SUNDIALS_HOME ?= ../../sundials/instdir
34

45
DEBUG = FALSE
56
USE_MPI = TRUE
@@ -12,11 +13,12 @@ USE_FFT = TRUE
1213
USE_ML = FALSE
1314
TINY_PROFILE = FALSE
1415
PROFILE = FALSE
15-
1616
USE_SUNDIALS = FALSE
17-
SUNDIALS_HOME ?= ../../sundials/instdir
1817

1918
ifeq ($(USE_ML),TRUE)
19+
20+
CPPFLAGS += -DAMREX_USE_ML
21+
2022
# Define a macro for the C++ preprocessor
2123
DEFINES += -DML_ENABLE -D_GLIBCXX_USE_CXX11_ABI=1
2224

Source/Demag_ml.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
#ifdef AMREX_USE_ML
2+
13
// MagneX_ML_Infer_Dynamic.cpp
24
#include "MagneX.H"
35
#include <torch/script.h>
@@ -318,4 +320,6 @@ void RunMLDemagOnBox(
318320

319321
// 5) Unpack back to MultiFab
320322
UnpackTensorToHfieldDynamic(denorm, H_demagfield, mfi, bx, b.expected_spatial);
321-
}
323+
}
324+
325+
#endif

Source/MagneX.H

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -176,13 +176,7 @@ void WritePlotfile(MultiFab& Ms,
176176
const Real& time,
177177
const int& plt_step);
178178

179-
180-
// void CalculateH_demag_ML(const Array< MultiFab, AMREX_SPACEDIM> & Mfield,
181-
// torch::jit::script::Module& x_norm_module,
182-
// torch::jit::script::Module& ml_module,
183-
// torch::jit::script::Module& y_norm_module,
184-
// Array< MultiFab, AMREX_SPACEDIM> & H_demagfield);
185-
179+
#ifdef AMREX_USE_ML
186180

187181
at::Tensor PackMfieldToTensorDynamic(
188182
const amrex::Array<amrex::MultiFab, AMREX_SPACEDIM>& Mfield,
@@ -216,3 +210,5 @@ void MoveModuleToDevice(torch::jit::script::Module& m,
216210

217211
// Read expected_spatial = [nx, ny, nz] from normalizer module
218212
amrex::IntVect GetExpectedSpatial(torch::jit::script::Module& x_norm_module);
213+
214+
#endif

Source/MagneX.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,11 @@ void InitializeMagneXNamespace() {
234234

235235
ml_enable = 0;
236236
pp.query("ml_enable",ml_enable);
237+
#ifndef AMREX_USE_ML
238+
if (ml_enable == 1) {
239+
amrex::Abort("ml_enable=1 requires USE_ML=TRUE");
240+
}
241+
#endif
237242

238243
diag_type = -1;
239244
pp.query("diag_type",diag_type);

Source/Make.package

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
CEXE_sources += Checkpoint.cpp
22
CEXE_sources += ComputeLLGRHS.cpp
33
CEXE_sources += Demagnetization.cpp
4+
CEXE_sources += Demag_ml.cpp
45
CEXE_sources += Diagnostics.cpp
56
CEXE_sources += EffectiveAnisotropyField.cpp
67
CEXE_sources += EffectiveDMIField.cpp
@@ -15,4 +16,3 @@ CEXE_headers += CartesianAlgorithm_K.H
1516
CEXE_headers += Demagnetization.H
1617
CEXE_headers += MagneX.H
1718
CEXE_headers += MagneX_namespace.H
18-
CEXE_sources += Demag_ml.cpp

Source/main.cpp

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
#include "MagneX.H"
22
#include "Demagnetization.H"
3-
#include <torch/script.h>
43
#include <AMReX_MultiFab.H>
54
#include <AMReX_VisMF.H>
65
#include <AMReX_ParmParse.H>
@@ -11,8 +10,12 @@
1110

1211
#include <cmath>
1312

13+
#ifdef AMREX_USE_ML
14+
#include <torch/script.h>
1415
#include <ATen/cuda/CUDAContext.h> // for at::cuda::setDevice
1516
#include <c10/cuda/CUDAGuard.h>
17+
#endif
18+
1619
using namespace amrex;
1720
using namespace MagneX;
1821

@@ -60,10 +63,11 @@ void main_main ()
6063
Array<MultiFab, AMREX_SPACEDIM> LLG_RHS;
6164
Array<MultiFab, AMREX_SPACEDIM> LLG_RHS_pre;
6265
Array<MultiFab, AMREX_SPACEDIM> LLG_RHS_avg;
66+
#ifdef AMREX_USE_ML
6367
torch::jit::script::Module ml_module;
6468
torch::jit::script::Module x_norm_module;
6569
torch::jit::script::Module y_norm_module;
66-
70+
#endif
6771

6872
// Declare variables for hysteresis
6973
Real normalized_Mx;
@@ -114,6 +118,7 @@ void main_main ()
114118
// **********************************
115119
// // LOAD PYTORCH MODEL
116120
if (ml_enable == 1) {
121+
#ifdef AMREX_USE_ML
117122
BL_PROFILE_VAR("LoadPytorch",LoadPytorch);
118123

119124
// Load pytorch module via torch script
@@ -164,6 +169,7 @@ void main_main ()
164169
<< expected_spatial[2] << "\n";
165170

166171
Print() << "Model loaded.\n";
172+
#endif
167173
}
168174
else {
169175
Print() << "ML disabled. Skipping model load.\n";
@@ -449,7 +455,7 @@ void main_main ()
449455
if (demag_coupling == 1) {
450456
// demag_solver.CalculateH_demag(Mfield_old, H_demagfield);
451457
if (ml_enable == 1) {
452-
// CalculateH_demag_ML(Mfield_old, x_norm_module, ml_module, y_norm_module, H_demagfield);
458+
#ifdef AMREX_USE_ML
453459
for (amrex::MFIter mfi(Mfield_old[0], amrex::TilingIfNotGPU());
454460
mfi.isValid(); ++mfi)
455461
{
@@ -460,14 +466,14 @@ void main_main ()
460466
Mfield_old, mfi, bx, expected_spatial, device_id
461467
);
462468

463-
//
464469
at::Tensor norm = NormalizeInput(M_cuda_f32, x_norm_module);
465470
at::Tensor pred = MLForwardOnly(norm, ml_module);
466471
at::Tensor denorm_f64 = DenormalizeOutput(pred, y_norm_module);
467472

468473
// unpack: tensor -> MultiFab
469474
UnpackTensorToHfieldDynamic(denorm_f64, H_demagfield, mfi, bx, expected_spatial);
470475
}
476+
#endif
471477
} else {
472478
// demag_solver.CalculateH_demag(Mfield_old, H_demagfield);
473479
amrex::Gpu::streamSynchronize();
@@ -599,8 +605,7 @@ void main_main ()
599605
if (demag_coupling == 1) {
600606
// demag_solver.CalculateH_demag(Mfield, H_demagfield);
601607
if (ml_enable == 1) {
602-
// CalculateH_demag_ML(Mfield, x_norm_module, ml_module, y_norm_module, H_demagfield);
603-
608+
#ifdef AMREX_USE_ML
604609
for (amrex::MFIter mfi(Mfield_old[0], amrex::TilingIfNotGPU());
605610
mfi.isValid(); ++mfi)
606611
{
@@ -618,6 +623,7 @@ void main_main ()
618623
// unpack: tensor -> MultiFab
619624
UnpackTensorToHfieldDynamic(denorm_f64, H_demagfield, mfi, bx, expected_spatial);
620625
}
626+
#endif
621627
} else {
622628
// demag_solver.CalculateH_demag(Mfield, H_demagfield);
623629
amrex::Gpu::streamSynchronize();
@@ -810,7 +816,7 @@ void main_main ()
810816
if (fast_demag==1) {
811817
// demag_solver.CalculateH_demag(ar_state, H_demagfield);
812818
if (ml_enable == 1) {
813-
CalculateH_demag_ML(ar_state, x_norm_module, ml_module, y_norm_module, H_demagfield);
819+
amrex::Abort("add ML demag to fast dynamics");
814820
} else {
815821
demag_solver.CalculateH_demag(ar_state, H_demagfield);
816822
}

0 commit comments

Comments
 (0)