Skip to content

Commit e2cba04

Browse files
committed
initial commit
1 parent b952477 commit e2cba04

7 files changed

Lines changed: 221 additions & 6 deletions

File tree

Exec/GNUmakefile

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# AMREX_HOME defines the directory in which we will find all the AMReX code.
2-
AMREX_HOME ?= ../../amrex
2+
AMREX_HOME = ../../amrex
33

44
DEBUG = FALSE
55
USE_MPI = TRUE
@@ -12,6 +12,32 @@ USE_FFT = TRUE
1212

1313
USE_SUNDIALS = FALSE
1414

15+
16+
# Pytorch directories
17+
ifeq ($(USE_CUDA),TRUE)
18+
PYTORCH_ROOT := ../../libtorch_cuda
19+
else
20+
PYTORCH_ROOT := ../../libtorch_cpu
21+
endif
22+
TORCH_LIBPATH = $(PYTORCH_ROOT)/lib
23+
24+
ifeq ($(USE_CUDA),TRUE)
25+
TORCH_LIBS = -ltorch -ltorch_cpu -lc10 -lc10_cuda -lcuda
26+
else
27+
TORCH_LIBS = -ltorch -ltorch_cpu -lc10
28+
endif
29+
30+
INCLUDE_LOCATIONS += $(PYTORCH_ROOT)/include \
31+
$(PYTORCH_ROOT)/include/torch/csrc/api/include
32+
LIBRARY_LOCATIONS += $(TORCH_LIBPATH)
33+
34+
DEFINES += -D_GLIBCXX_USE_CXX11_ABI=1
35+
ifeq ($(USE_CUDA),TRUE)
36+
LDFLAGS += -Xlinker "--no-as-needed,-rpath $(TORCH_LIBPATH) $(TORCH_LIBS)"
37+
else
38+
LDFLAGS += -Wl,--no-as-needed,-rpath=$(TORCH_LIBPATH) $(TORCH_LIBS)
39+
endif
40+
1541
include $(AMREX_HOME)/Tools/GNUMake/Make.defs
1642

1743
include ../Source/Make.package

Source/Demag_ml.cpp

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
#include "MagneX.H"
2+
#include <torch/script.h>
3+
4+
using namespace amrex;
5+
6+
void CalculateH_demag_ML(const Array<MultiFab, AMREX_SPACEDIM>& Mfield,
7+
torch::jit::script::Module& x_norm_module,
8+
torch::jit::script::Module& ml_module,
9+
torch::jit::script::Module& y_norm_module,
10+
Array<MultiFab, AMREX_SPACEDIM>& H_demagfield)
11+
{
12+
BL_PROFILE_VAR("CalculateH_demag_ML()", CalculateH_demag_ML);
13+
14+
for (MFIter mfi(Mfield[0], TilingIfNotGPU()); mfi.isValid(); ++mfi) {
15+
16+
const Box& bx = mfi.validbox();
17+
18+
const auto& Mx = Mfield[0].const_array(mfi);
19+
const auto& My = Mfield[1].const_array(mfi);
20+
const auto& Mz = Mfield[2].const_array(mfi);
21+
22+
auto Hx_demag = H_demagfield[0].array(mfi);
23+
auto Hy_demag = H_demagfield[1].array(mfi);
24+
auto Hz_demag = H_demagfield[2].array(mfi);
25+
26+
const IntVect bx_lo = bx.smallEnd();
27+
const IntVect nbox = bx.size();
28+
29+
#if AMREX_SPACEDIM == 2
30+
const int ncell = nbox[0] * nbox[1];
31+
#else
32+
const int ncell = nbox[0] * nbox[1] * nbox[2];
33+
#endif
34+
35+
// Host-visible (Managed) buffers filled on GPU
36+
amrex::Gpu::ManagedVector<Real> aux_Mx(ncell);
37+
amrex::Gpu::ManagedVector<Real> aux_My(ncell);
38+
amrex::Gpu::ManagedVector<Real> aux_Mz(ncell);
39+
40+
Real* AMREX_RESTRICT auxPtr_Mx = aux_Mx.dataPtr();
41+
Real* AMREX_RESTRICT auxPtr_My = aux_My.dataPtr();
42+
Real* AMREX_RESTRICT auxPtr_Mz = aux_Mz.dataPtr();
43+
44+
// Fill aux buffers from MultiFab on GPU
45+
amrex::ParallelFor(bx, [=] AMREX_GPU_DEVICE(int i, int j, int k) noexcept {
46+
const int ii = i - bx_lo[0];
47+
const int jj = j - bx_lo[1];
48+
49+
#if AMREX_SPACEDIM == 2
50+
const int index = jj + ii * nbox[1];
51+
#else
52+
const int kk = k - bx_lo[2];
53+
const int index = kk + jj * nbox[2] + ii * nbox[2] * nbox[1];
54+
#endif
55+
56+
auxPtr_Mx[index] = Mx(i, j, k);
57+
auxPtr_My[index] = My(i, j, k);
58+
auxPtr_Mz[index] = Mz(i, j, k);
59+
});
60+
61+
// Make sure aux buffers are ready for from_blob on host side
62+
amrex::Gpu::streamSynchronize();
63+
64+
// Wrap buffers as CPU tensors (no copy)
65+
at::Tensor inputs_torch_Mx = torch::from_blob(auxPtr_Mx, {ncell, 1}, torch::kFloat64);
66+
at::Tensor inputs_torch_My = torch::from_blob(auxPtr_My, {ncell, 1}, torch::kFloat64);
67+
at::Tensor inputs_torch_Mz = torch::from_blob(auxPtr_Mz, {ncell, 1}, torch::kFloat64);
68+
69+
// Reshape (assumes ncell == 128*128*4)
70+
at::Tensor reshaped_Mx = inputs_torch_Mx.reshape({128, 128, 4});
71+
at::Tensor reshaped_My = inputs_torch_My.reshape({128, 128, 4});
72+
at::Tensor reshaped_Mz = inputs_torch_Mz.reshape({128, 128, 4});
73+
74+
// Stack into [3, 128, 128, 4] then add batch dim -> [1, 3, 128, 128, 4]
75+
at::Tensor final_tensor_M = torch::stack({reshaped_Mx, reshaped_My, reshaped_Mz}, 0);
76+
final_tensor_M = final_tensor_M.to(torch::kCUDA).to(torch::kFloat32);
77+
final_tensor_M = final_tensor_M.unsqueeze(0);
78+
79+
// Normalize -> model -> denormalize
80+
at::Tensor norm_torch = x_norm_module.get_method("encode")({final_tensor_M}).toTensor();
81+
at::Tensor outputs_torch = ml_module.forward({norm_torch}).toTensor();
82+
at::Tensor denorm_torch = y_norm_module.get_method("decode")({outputs_torch}).toTensor();
83+
84+
// Convert to float64 for accessor<Real,...> usage (same as your original)
85+
denorm_torch = denorm_torch.to(torch::kFloat64);
86+
87+
// Extract H components: denorm_torch shape assumed [1, 3, 128, 128, 4]
88+
at::Tensor denorm_torch_Hx = denorm_torch.select(0, 0).select(0, 0).flatten();
89+
at::Tensor denorm_torch_Hy = denorm_torch.select(0, 0).select(0, 1).flatten();
90+
at::Tensor denorm_torch_Hz = denorm_torch.select(0, 0).select(0, 2).flatten();
91+
92+
#ifdef AMREX_USE_CUDA
93+
auto denorm_torch_Hx_acc = denorm_torch_Hx.packed_accessor64<Real, 1>();
94+
auto denorm_torch_Hy_acc = denorm_torch_Hy.packed_accessor64<Real, 1>();
95+
auto denorm_torch_Hz_acc = denorm_torch_Hz.packed_accessor64<Real, 1>();
96+
#endif
97+
98+
// Copy tensor data back into demag field
99+
amrex::ParallelFor(bx, [=] AMREX_GPU_DEVICE(int i, int j, int k) noexcept {
100+
const int ii = i - bx_lo[0];
101+
const int jj = j - bx_lo[1];
102+
const int kk = k - bx_lo[2];
103+
const int index = kk + jj * nbox[2] + ii * nbox[2] * nbox[1];
104+
105+
Hx_demag(i, j, k) = denorm_torch_Hx_acc[index];
106+
Hy_demag(i, j, k) = denorm_torch_Hy_acc[index];
107+
Hz_demag(i, j, k) = denorm_torch_Hz_acc[index];
108+
});
109+
110+
amrex::Gpu::streamSynchronize();
111+
}
112+
}

Source/MagneX.H

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#ifdef AMREX_USE_CUDA
22
#include <cufft.h>
3+
#include <torch/script.h>
34
#else
45
#include <fftw3.h>
56
#ifdef AMREX_USE_MPI
@@ -184,3 +185,11 @@ void WritePlotfile(MultiFab& Ms,
184185
const Geometry& geom,
185186
const Real& time,
186187
const int& plt_step);
188+
189+
190+
void CalculateH_demag_ML(const Array< MultiFab, AMREX_SPACEDIM> & Mfield,
191+
torch::jit::script::Module& x_norm_module,
192+
torch::jit::script::Module& ml_module,
193+
torch::jit::script::Module& y_norm_module,
194+
Array< MultiFab, AMREX_SPACEDIM> & H_demagfield);
195+

Source/MagneX.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,10 @@ AMREX_GPU_MANAGED int MagneX::demag_coupling;
132132
// 0 = FFTW (single-MPI), 1 = heFFTe (distributed)
133133
AMREX_GPU_MANAGED int MagneX::FFT_solver;
134134

135+
// ML flag
136+
int MagneX::ml_enable;
137+
138+
135139
void InitializeMagneXNamespace() {
136140

137141
BL_PROFILE_VAR("InitializeMagneXNamespace()",InitializeMagneXNameSpace);
@@ -228,6 +232,9 @@ void InitializeMagneXNamespace() {
228232
restart = -1;
229233
pp.query("restart",restart);
230234

235+
ml_enable = 0;
236+
pp.query("ml_enable",ml_enable);
237+
231238
diag_type = -1;
232239
pp.query("diag_type",diag_type);
233240

Source/MagneX_namespace.H

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,8 @@ namespace MagneX {
5252

5353
extern int diag_type;
5454

55+
extern int ml_enable;
56+
5557
extern int timedependent_Hbias;
5658
extern int timedependent_alpha;
5759

Source/Make.package

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,4 @@ CEXE_headers += CartesianAlgorithm_K.H
1515
CEXE_headers += Demagnetization.H
1616
CEXE_headers += MagneX.H
1717
CEXE_headers += MagneX_namespace.H
18+
CEXE_sources += Demag_ml.cpp

Source/main.cpp

Lines changed: 63 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,17 @@
11
#include "MagneX.H"
22
#include "Demagnetization.H"
3-
3+
#include <torch/script.h>
44
#include <AMReX_MultiFab.H>
55
#include <AMReX_VisMF.H>
6-
6+
#include <AMReX_ParmParse.H>
77
#ifdef AMREX_USE_SUNDIALS
88
#include <AMReX_TimeIntegrator.H>
99
#endif
1010

1111
#include <cmath>
1212

13+
#include <ATen/cuda/CUDAContext.h> // for at::cuda::setDevice
14+
#include <c10/cuda/CUDAGuard.h>
1315
using namespace amrex;
1416
using namespace MagneX;
1517

@@ -57,6 +59,10 @@ void main_main ()
5759
Array<MultiFab, AMREX_SPACEDIM> LLG_RHS;
5860
Array<MultiFab, AMREX_SPACEDIM> LLG_RHS_pre;
5961
Array<MultiFab, AMREX_SPACEDIM> LLG_RHS_avg;
62+
torch::jit::script::Module ml_module;
63+
torch::jit::script::Module x_norm_module;
64+
torch::jit::script::Module y_norm_module;
65+
6066

6167
// Declare variables for hysteresis
6268
Real normalized_Mx;
@@ -100,6 +106,42 @@ void main_main ()
100106

101107
}
102108

109+
// **********************************
110+
// // LOAD PYTORCH MODEL
111+
112+
BL_PROFILE_VAR("LoadPytorch",LoadPytorch);
113+
114+
// Load pytorch module via torch script
115+
116+
117+
std::string ml_model_name;
118+
std::string x_normalizer_name;
119+
std::string y_normalizer_name;
120+
121+
ParmParse pp_ml;
122+
pp_ml.query("ml_model_name", ml_model_name);
123+
pp_ml.query("x_normalizer_name", x_normalizer_name);
124+
pp_ml.query("y_normalizer_name", y_normalizer_name);
125+
126+
amrex::Print()<<"\n"<<ml_model_name<<"\n";
127+
amrex::Print()<<x_normalizer_name<<"\n";
128+
amrex::Print()<<y_normalizer_name<<"\n";
129+
130+
int dev_id = amrex::Gpu::Device::deviceId();
131+
c10::cuda::CUDAGuard device_guard(dev_id);
132+
torch::Device dev(torch::kCUDA, dev_id);
133+
try {
134+
// Deserialize the ScriptModule from a file using torch::jit::load().
135+
ml_module = torch::jit::load(ml_model_name, dev);
136+
x_norm_module = torch::jit::load(x_normalizer_name, dev);
137+
y_norm_module = torch::jit::load(y_normalizer_name, dev);
138+
}
139+
catch (const c10::Error& e) {
140+
amrex::Abort("Error loading the model\n");
141+
}
142+
143+
Print() << "Model loaded.\n";
144+
103145
// **********************************
104146
// SIMULATION SETUP
105147

@@ -378,7 +420,12 @@ void main_main ()
378420

379421
// Evolve H_demag
380422
if (demag_coupling == 1) {
381-
demag_solver.CalculateH_demag(Mfield_old, H_demagfield);
423+
// demag_solver.CalculateH_demag(Mfield_old, H_demagfield);
424+
if (ml_enable == 1) {
425+
CalculateH_demag_ML(Mfield_old, x_norm_module, ml_module, y_norm_module, H_demagfield);
426+
} else {
427+
demag_solver.CalculateH_demag(Mfield_old, H_demagfield);
428+
}
382429
}
383430

384431
if (exchange_coupling == 1) {
@@ -496,7 +543,13 @@ void main_main ()
496543

497544
// Poisson solve and H_demag computation with Mfield
498545
if (demag_coupling == 1) {
499-
demag_solver.CalculateH_demag(Mfield, H_demagfield);
546+
// demag_solver.CalculateH_demag(Mfield, H_demagfield);
547+
if (ml_enable == 1) {
548+
CalculateH_demag_ML(Mfield, x_norm_module, ml_module, y_norm_module, H_demagfield);
549+
} else {
550+
demag_solver.CalculateH_demag(Mfield, H_demagfield);
551+
}
552+
500553
}
501554

502555
if (exchange_coupling == 1) {
@@ -662,7 +715,12 @@ void main_main ()
662715
// H_demag
663716
if (demag_coupling == 1) {
664717
if (fast_demag==1) {
665-
demag_solver.CalculateH_demag(ar_state, H_demagfield);
718+
// demag_solver.CalculateH_demag(ar_state, H_demagfield);
719+
if (ml_enable == 1) {
720+
CalculateH_demag_ML(ar_state, x_norm_module, ml_module, y_norm_module, H_demagfield);
721+
} else {
722+
demag_solver.CalculateH_demag(ar_state, H_demagfield);
723+
}
666724
} else {
667725
for (int idim=0; idim<AMREX_SPACEDIM; ++idim) {
668726
H_demagfield[idim].setVal(0.);

0 commit comments

Comments
 (0)