@@ -228,7 +228,8 @@ __launch_bounds__(unary_kernel_threads) __global__
228228 loader.load (tid, size);
229229#pragma unroll
230230 for (int i = 0 ; i < nvec; ++i) {
231- const size_t global_idx = (aligned ? (tid * nvec + i) : (tid * nvec + i - loader.alignment ()));
231+ const size_t global_idx =
232+ (aligned ? (tid * nvec + i) : (tid * nvec + i - loader.alignment ()));
232233 if (global_idx >= size) continue ;
233234
234235 ComputeType val = static_cast <ComputeType>(loader.separate ()[i]);
@@ -332,7 +333,8 @@ __launch_bounds__(unary_kernel_threads) __global__
332333 grad_loader.load (tid, size);
333334#pragma unroll
334335 for (int i = 0 ; i < nvec; ++i) {
335- const size_t global_idx = (aligned ? (tid * nvec + i) : (tid * nvec + i - loader.alignment ()));
336+ const size_t global_idx =
337+ (aligned ? (tid * nvec + i) : (tid * nvec + i - loader.alignment ()));
336338 if (global_idx >= size) continue ;
337339
338340 ComputeType val = static_cast <ComputeType>(loader.separate ()[i]);
@@ -466,19 +468,19 @@ void VectorizedUnaryKernelLauncher(const InputType *input, const fp32 *noop, Out
466468 switch (align) {
467469 case Alignment::SAME_ALIGNED :
468470 unary_kernel<nvec, true , fp32, Param, OP ><<<grid, threads, 0 , stream>>>(
469- input, noop, output, scale, amax, scale_inv, params, N, num_aligned_elements,
470- offsets, first_dims, last_dims, num_tensors, scale_numel, scale_inv_numel, amax_numel);
471+ input, noop, output, scale, amax, scale_inv, params, N, num_aligned_elements, offsets,
472+ first_dims, last_dims, num_tensors, scale_numel, scale_inv_numel, amax_numel);
471473 break ;
472474 case Alignment::SAME_UNALIGNED :
473475 unary_kernel<nvec, false , fp32, Param, OP ><<<grid, threads, 0 , stream>>>(
474- input, noop, output, scale, amax, scale_inv, params, N, num_aligned_elements,
475- offsets, first_dims, last_dims, num_tensors, scale_numel, scale_inv_numel, amax_numel);
476+ input, noop, output, scale, amax, scale_inv, params, N, num_aligned_elements, offsets,
477+ first_dims, last_dims, num_tensors, scale_numel, scale_inv_numel, amax_numel);
476478 break ;
477479 case Alignment::DIFFERENT : {
478480 // If the pointers are aligned differently we cannot vectorize
479481 unary_kernel<1 , true , fp32, Param, OP ><<<grid, threads, 0 , stream>>>(
480- input, noop, output, scale, amax, scale_inv, params, N, N,
481- offsets, first_dims, last_dims, num_tensors, scale_numel, scale_inv_numel, amax_numel);
482+ input, noop, output, scale, amax, scale_inv, params, N, N, offsets, first_dims,
483+ last_dims, num_tensors, scale_numel, scale_inv_numel, amax_numel);
482484 break ;
483485 }
484486 }
@@ -508,19 +510,19 @@ void VectorizedUnaryGradKernelLauncher(const InputTypeGrad *grad, const InputTyp
508510 switch (align) {
509511 case Alignment::SAME_ALIGNED :
510512 unary_grad_kernel<nvec, true , fp32, Param, OP ><<<grid, threads, 0 , stream>>>(
511- grad, input, output, scale, amax, scale_inv, params, N, num_aligned_elements,
512- offsets, first_dims, last_dims, num_tensors, scale_numel, scale_inv_numel, amax_numel);
513+ grad, input, output, scale, amax, scale_inv, params, N, num_aligned_elements, offsets,
514+ first_dims, last_dims, num_tensors, scale_numel, scale_inv_numel, amax_numel);
513515 break ;
514516 case Alignment::SAME_UNALIGNED :
515517 unary_grad_kernel<nvec, false , fp32, Param, OP ><<<grid, threads, 0 , stream>>>(
516- grad, input, output, scale, amax, scale_inv, params, N, num_aligned_elements,
517- offsets, first_dims, last_dims, num_tensors, scale_numel, scale_inv_numel, amax_numel);
518+ grad, input, output, scale, amax, scale_inv, params, N, num_aligned_elements, offsets,
519+ first_dims, last_dims, num_tensors, scale_numel, scale_inv_numel, amax_numel);
518520 break ;
519521 case Alignment::DIFFERENT : {
520522 // If the pointers are aligned differently we cannot vectorize
521523 unary_grad_kernel<1 , true , fp32, Param, OP ><<<grid, threads, 0 , stream>>>(
522- grad, input, output, scale, amax, scale_inv, params, N, N,
523- offsets, first_dims, last_dims, num_tensors, scale_numel, scale_inv_numel, amax_numel);
524+ grad, input, output, scale, amax, scale_inv, params, N, N, offsets, first_dims,
525+ last_dims, num_tensors, scale_numel, scale_inv_numel, amax_numel);
524526 break ;
525527 }
526528 }
0 commit comments