Skip to content

Commit ceb830b

Browse files
authored
Merge pull request QMCPACK#5594 from ye-luo/update-du-gpu
Slight adjustment in batched delayed update
2 parents 4e48cd0 + 2872830 commit ceb830b

3 files changed

Lines changed: 26 additions & 19 deletions

File tree

src/QMCWaveFunctions/Fermion/DelayedUpdateBatched.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -645,11 +645,11 @@ class DelayedUpdateBatched
645645
//std::copy_n(Ainv[rowchanged], norb, V[delay_count]);
646646
compute::BLAS::copy_batched(blas_handle, norb, invRow_mw_ptr, 1, V_row_mw_ptr, 1, nw);
647647
// handle accepted walkers
648-
// the new Binv is [[X Y] [Z sigma]]
648+
// the new Binv is [[X y] [z sigma]]
649649
//BLAS::gemv('T', norb, delay_count + 1, cminusone, V.data(), norb, psiV.data(), 1, czero, p.data(), 1);
650650
compute::BLAS::gemv_batched(blas_handle, 'T', norb, delay_count, cminusone_vec.device_data(), V_mw_ptr, norb,
651651
phiVGL_mw_ptr, 1, czero_vec.device_data(), p_mw_ptr, 1, n_accepted);
652-
// Y
652+
// y
653653
//BLAS::gemv('T', delay_count, delay_count, sigma, Binv.data(), lda_Binv, p.data(), 1, czero, Binv.data() + delay_count,
654654
// lda_Binv);
655655
compute::BLAS::gemv_batched(blas_handle, 'T', delay_count, delay_count, ratio_inv_mw_ptr, Binv_mw_ptr, lda_Binv,

src/QMCWaveFunctions/detail/CUDA/matrix_update_helper.cu

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -321,20 +321,16 @@ __global__ void add_delay_list_save_sigma_VGL_kernel(int* const delay_list[],
321321

322322
if (iw < n_accepted)
323323
{
324-
// real accept, settle y and Z
324+
// real accept
325325
int* __restrict__ delay_list_iw = delay_list[iw];
326326
T* __restrict__ binvrow_iw = binv[iw] + delay_count * binv_lda;
327-
const T* __restrict__ phi_in_iw = phi_vgl_in[iw];
328-
T* __restrict__ phi_out_iw = phi_out[iw];
329-
T* __restrict__ dphi_out_iw = dphi_out[iw];
330-
T* __restrict__ d2phi_out_iw = d2phi_out[iw];
331-
332327
if (tid == 0)
333328
{
334329
delay_list_iw[delay_count] = rowchanged;
335330
binvrow_iw[delay_count] = ratio_inv[iw];
336331
}
337332

333+
// Settle z by applying the final rescaling.
338334
const int num_delay_count_col_blocks = (delay_count + COLBS - 1) / COLBS;
339335
for (int ib = 0; ib < num_delay_count_col_blocks; ib++)
340336
{
@@ -343,6 +339,12 @@ __global__ void add_delay_list_save_sigma_VGL_kernel(int* const delay_list[],
343339
binvrow_iw[col_id] *= ratio_inv[iw];
344340
}
345341

342+
// Save VGL
343+
const T* __restrict__ phi_in_iw = phi_vgl_in[iw];
344+
T* __restrict__ phi_out_iw = phi_out[iw];
345+
T* __restrict__ dphi_out_iw = dphi_out[iw];
346+
T* __restrict__ d2phi_out_iw = d2phi_out[iw];
347+
346348
const int num_col_blocks = (norb + COLBS - 1) / COLBS;
347349
for (int ib = 0; ib < num_col_blocks; ib++)
348350
{
@@ -360,7 +362,7 @@ __global__ void add_delay_list_save_sigma_VGL_kernel(int* const delay_list[],
360362
}
361363
else
362364
{
363-
// fake accept. Set Y, Z with zero and x with 1
365+
// pseudo accept
364366
T* __restrict__ Urow_iw = phi_out[iw];
365367
const int num_blocks_norb = (norb + COLBS - 1) / COLBS;
366368
for (int ib = 0; ib < num_blocks_norb; ib++)
@@ -370,15 +372,17 @@ __global__ void add_delay_list_save_sigma_VGL_kernel(int* const delay_list[],
370372
Urow_iw[col_id] = T(0);
371373
}
372374

375+
// Set y to zero
373376
T* __restrict__ binv_iw = binv[iw];
374377
const int num_blocks_delay_count = (delay_count + COLBS - 1) / COLBS;
375378
for (int ib = 0; ib < num_blocks_delay_count; ib++)
376379
{
377380
const int col_id = ib * COLBS + tid;
378381
if (col_id < delay_count)
379-
binv_iw[delay_count * binv_lda + col_id] = binv_iw[delay_count + binv_lda * col_id] = T(0);
382+
binv_iw[delay_count + binv_lda * col_id] = T(0);
380383
}
381384

385+
// Set x to 1
382386
int* __restrict__ delay_list_iw = delay_list[iw];
383387
if (tid == 0)
384388
{

src/QMCWaveFunctions/detail/SYCL/matrix_update_helper.cpp

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -233,20 +233,16 @@ sycl::event add_delay_list_save_sigma_VGL_batched(sycl::queue& aq,
233233

234234
if (iw < n_accepted)
235235
{
236-
// real accept, settle y and Z
236+
// real accept
237237
int* __restrict__ delay_list_iw = delay_list[iw];
238238
T* __restrict__ binvrow_iw = binv[iw] + delay_count * binv_lda;
239-
const T* __restrict__ phi_in_iw = phi_vgl_in[iw];
240-
T* __restrict__ phi_out_iw = phi_out[iw];
241-
T* __restrict__ dphi_out_iw = dphi_out[iw];
242-
T* __restrict__ d2phi_out_iw = d2phi_out[iw];
243-
244239
if (tid == 0)
245240
{
246241
delay_list_iw[delay_count] = rowchanged;
247242
binvrow_iw[delay_count] = ratio_inv[iw];
248243
}
249244

245+
// Settle z
250246
const int num_delay_count_col_blocks = (delay_count + COLBS - 1) / COLBS;
251247
for (int ib = 0; ib < num_delay_count_col_blocks; ib++)
252248
{
@@ -255,6 +251,12 @@ sycl::event add_delay_list_save_sigma_VGL_batched(sycl::queue& aq,
255251
binvrow_iw[col_id] *= ratio_inv[iw];
256252
}
257253

254+
// Save VGL
255+
const T* __restrict__ phi_in_iw = phi_vgl_in[iw];
256+
T* __restrict__ phi_out_iw = phi_out[iw];
257+
T* __restrict__ dphi_out_iw = dphi_out[iw];
258+
T* __restrict__ d2phi_out_iw = d2phi_out[iw];
259+
258260
const int num_col_blocks = (norb + COLBS - 1) / COLBS;
259261
for (int ib = 0; ib < num_col_blocks; ib++)
260262
{
@@ -272,7 +274,7 @@ sycl::event add_delay_list_save_sigma_VGL_batched(sycl::queue& aq,
272274
}
273275
else
274276
{
275-
// fake accept. Set Y, Z with zero and x with 1
277+
// pseudo accept
276278
T* __restrict__ Urow_iw = phi_out[iw];
277279
const int num_blocks_norb = (norb + COLBS - 1) / COLBS;
278280
for (int ib = 0; ib < num_blocks_norb; ib++)
@@ -282,16 +284,17 @@ sycl::event add_delay_list_save_sigma_VGL_batched(sycl::queue& aq,
282284
Urow_iw[col_id] = T{};
283285
}
284286

287+
// Set y to zero
285288
T* __restrict__ binv_iw = binv[iw];
286289
const int num_blocks_delay_count = (delay_count + COLBS - 1) / COLBS;
287290
for (int ib = 0; ib < num_blocks_delay_count; ib++)
288291
{
289292
const int col_id = ib * COLBS + tid;
290293
if (col_id < delay_count)
291-
binv_iw[delay_count * binv_lda + col_id] = binv_iw[delay_count + binv_lda * col_id] =
292-
T(0);
294+
binv_iw[delay_count + binv_lda * col_id] = T(0);
293295
}
294296

297+
// Set x to 1
295298
int* __restrict__ delay_list_iw = delay_list[iw];
296299
if (tid == 0)
297300
{

0 commit comments

Comments
 (0)