1+ #include " MagneX.H"
2+ #include < torch/script.h>
3+
4+ using namespace amrex ;
5+
6+ void CalculateH_demag_ML (const Array<MultiFab, AMREX_SPACEDIM>& Mfield,
7+ torch::jit::script::Module& x_norm_module,
8+ torch::jit::script::Module& ml_module,
9+ torch::jit::script::Module& y_norm_module,
10+ Array<MultiFab, AMREX_SPACEDIM>& H_demagfield)
11+ {
12+ BL_PROFILE_VAR (" CalculateH_demag_ML()" , CalculateH_demag_ML);
13+
14+ for (MFIter mfi (Mfield[0 ], TilingIfNotGPU ()); mfi.isValid (); ++mfi) {
15+
16+ const Box& bx = mfi.validbox ();
17+
18+ const auto & Mx = Mfield[0 ].const_array (mfi);
19+ const auto & My = Mfield[1 ].const_array (mfi);
20+ const auto & Mz = Mfield[2 ].const_array (mfi);
21+
22+ auto Hx_demag = H_demagfield[0 ].array (mfi);
23+ auto Hy_demag = H_demagfield[1 ].array (mfi);
24+ auto Hz_demag = H_demagfield[2 ].array (mfi);
25+
26+ const IntVect bx_lo = bx.smallEnd ();
27+ const IntVect nbox = bx.size ();
28+
29+ #if AMREX_SPACEDIM == 2
30+ const int ncell = nbox[0 ] * nbox[1 ];
31+ #else
32+ const int ncell = nbox[0 ] * nbox[1 ] * nbox[2 ];
33+ #endif
34+
35+ // Host-visible (Managed) buffers filled on GPU
36+ amrex::Gpu::ManagedVector<Real> aux_Mx (ncell);
37+ amrex::Gpu::ManagedVector<Real> aux_My (ncell);
38+ amrex::Gpu::ManagedVector<Real> aux_Mz (ncell);
39+
40+ Real* AMREX_RESTRICT auxPtr_Mx = aux_Mx.dataPtr ();
41+ Real* AMREX_RESTRICT auxPtr_My = aux_My.dataPtr ();
42+ Real* AMREX_RESTRICT auxPtr_Mz = aux_Mz.dataPtr ();
43+
44+ // Fill aux buffers from MultiFab on GPU
45+ amrex::ParallelFor (bx, [=] AMREX_GPU_DEVICE (int i, int j, int k) noexcept {
46+ const int ii = i - bx_lo[0 ];
47+ const int jj = j - bx_lo[1 ];
48+
49+ #if AMREX_SPACEDIM == 2
50+ const int index = jj + ii * nbox[1 ];
51+ #else
52+ const int kk = k - bx_lo[2 ];
53+ const int index = kk + jj * nbox[2 ] + ii * nbox[2 ] * nbox[1 ];
54+ #endif
55+
56+ auxPtr_Mx[index] = Mx (i, j, k);
57+ auxPtr_My[index] = My (i, j, k);
58+ auxPtr_Mz[index] = Mz (i, j, k);
59+ });
60+
61+ // Make sure aux buffers are ready for from_blob on host side
62+ amrex::Gpu::streamSynchronize ();
63+
64+ // Wrap buffers as CPU tensors (no copy)
65+ at::Tensor inputs_torch_Mx = torch::from_blob (auxPtr_Mx, {ncell, 1 }, torch::kFloat64 );
66+ at::Tensor inputs_torch_My = torch::from_blob (auxPtr_My, {ncell, 1 }, torch::kFloat64 );
67+ at::Tensor inputs_torch_Mz = torch::from_blob (auxPtr_Mz, {ncell, 1 }, torch::kFloat64 );
68+
69+ // Reshape (assumes ncell == 128*128*4)
70+ at::Tensor reshaped_Mx = inputs_torch_Mx.reshape ({128 , 128 , 4 });
71+ at::Tensor reshaped_My = inputs_torch_My.reshape ({128 , 128 , 4 });
72+ at::Tensor reshaped_Mz = inputs_torch_Mz.reshape ({128 , 128 , 4 });
73+
74+ // Stack into [3, 128, 128, 4] then add batch dim -> [1, 3, 128, 128, 4]
75+ at::Tensor final_tensor_M = torch::stack ({reshaped_Mx, reshaped_My, reshaped_Mz}, 0 );
76+ final_tensor_M = final_tensor_M.to (torch::kCUDA ).to (torch::kFloat32 );
77+ final_tensor_M = final_tensor_M.unsqueeze (0 );
78+
79+ // Normalize -> model -> denormalize
80+ at::Tensor norm_torch = x_norm_module.get_method (" encode" )({final_tensor_M}).toTensor ();
81+ at::Tensor outputs_torch = ml_module.forward ({norm_torch}).toTensor ();
82+ at::Tensor denorm_torch = y_norm_module.get_method (" decode" )({outputs_torch}).toTensor ();
83+
84+ // Convert to float64 for accessor<Real,...> usage (same as your original)
85+ denorm_torch = denorm_torch.to (torch::kFloat64 );
86+
87+ // Extract H components: denorm_torch shape assumed [1, 3, 128, 128, 4]
88+ at::Tensor denorm_torch_Hx = denorm_torch.select (0 , 0 ).select (0 , 0 ).flatten ();
89+ at::Tensor denorm_torch_Hy = denorm_torch.select (0 , 0 ).select (0 , 1 ).flatten ();
90+ at::Tensor denorm_torch_Hz = denorm_torch.select (0 , 0 ).select (0 , 2 ).flatten ();
91+
92+ #ifdef AMREX_USE_CUDA
93+ auto denorm_torch_Hx_acc = denorm_torch_Hx.packed_accessor64 <Real, 1 >();
94+ auto denorm_torch_Hy_acc = denorm_torch_Hy.packed_accessor64 <Real, 1 >();
95+ auto denorm_torch_Hz_acc = denorm_torch_Hz.packed_accessor64 <Real, 1 >();
96+ #endif
97+
98+ // Copy tensor data back into demag field
99+ amrex::ParallelFor (bx, [=] AMREX_GPU_DEVICE (int i, int j, int k) noexcept {
100+ const int ii = i - bx_lo[0 ];
101+ const int jj = j - bx_lo[1 ];
102+ const int kk = k - bx_lo[2 ];
103+ const int index = kk + jj * nbox[2 ] + ii * nbox[2 ] * nbox[1 ];
104+
105+ Hx_demag (i, j, k) = denorm_torch_Hx_acc[index];
106+ Hy_demag (i, j, k) = denorm_torch_Hy_acc[index];
107+ Hz_demag (i, j, k) = denorm_torch_Hz_acc[index];
108+ });
109+
110+ amrex::Gpu::streamSynchronize ();
111+ }
112+ }
0 commit comments