Skip to content

Commit 0432d95

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent d5fb0bf commit 0432d95

4 files changed

Lines changed: 45 additions & 42 deletions

File tree

docker_build_and_test.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,4 +82,4 @@ docker run --gpus all -it --rm \
8282
echo "=== Running operator tests ==="
8383
cd /workspace/TransformerEngine/tests/cpp
8484
./build/operator/test_operator "$@"
85-
' _ "${TEST_ARGS[@]}"
85+
' _ "${TEST_ARGS[@]}"

patch_swizzle.py

Lines changed: 27 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88
template <int SF_TILE_DIM_M, int SF_TILE_DIM_K>
99
__global__ void __launch_bounds__(TB_DIM* TB_DIM)
1010
grouped_swizzle_scaling_variable_shape_kernel(
11-
const void* input,
12-
void* output,
11+
const void* input,
12+
void* output,
1313
const int64_t* m_array,
1414
const int64_t* k_array,
1515
const int* block_offsets,
@@ -42,23 +42,23 @@
4242
if (tensor_id == -1) return;
4343
4444
int local_block_id = linear_block_id - block_offsets[tensor_id];
45-
45+
4646
size_t M = rowwise ? m_array[tensor_id] : k_array[tensor_id];
4747
size_t K = rowwise ? k_array[tensor_id] : m_array[tensor_id];
48-
48+
4949
size_t padded_m = round_up_to_multiple(M, 128);
5050
size_t padded_k = round_up_to_multiple(DIVUP(K, static_cast<size_t>(MXFP8_BLOCK_SIZE)), 4);
51-
51+
5252
int num_tiles_m = padded_m / SF_TILE_DIM_M;
5353
int num_tiles_k = padded_k / SF_TILE_DIM_K;
54-
54+
5555
int vec_load_size = (rowwise ? ((num_tiles_k - 1) % 4 + 1) : ((num_tiles_m - 1) % 4 + 1));
5656
if (vec_load_size == 3) vec_load_size = 1;
5757
int n_tiles_in_tb = TB_DIM * vec_load_size;
5858
5959
int grid_dim_x = rowwise ? DIVUP(num_tiles_k, n_tiles_in_tb) : DIVUP(num_tiles_k, TB_DIM);
6060
int grid_dim_y = rowwise ? num_tiles_m : DIVUP(num_tiles_m, vec_load_size);
61-
61+
6262
int block_x = local_block_id % grid_dim_x;
6363
int block_y = local_block_id / grid_dim_x;
6464
@@ -71,29 +71,29 @@
7171
if (rowwise) {
7272
if (vec_load_size == 4) {
7373
swizzle_row_scaling_kernel_impl<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>(
74-
input_base, output_base, padded_m, padded_k, original_M, original_K,
74+
input_base, output_base, padded_m, padded_k, original_M, original_K,
7575
block_x, block_y, grid_dim_x, grid_dim_y);
7676
} else if (vec_load_size == 2) {
7777
swizzle_row_scaling_kernel_impl<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>(
78-
input_base, output_base, padded_m, padded_k, original_M, original_K,
78+
input_base, output_base, padded_m, padded_k, original_M, original_K,
7979
block_x, block_y, grid_dim_x, grid_dim_y);
8080
} else {
8181
swizzle_row_scaling_kernel_impl<int, SF_TILE_DIM_M, SF_TILE_DIM_K>(
82-
input_base, output_base, padded_m, padded_k, original_M, original_K,
82+
input_base, output_base, padded_m, padded_k, original_M, original_K,
8383
block_x, block_y, grid_dim_x, grid_dim_y);
8484
}
8585
} else {
8686
if (vec_load_size == 4) {
8787
swizzle_col_scaling_kernel_impl<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>(
88-
input_base, output_base, padded_m, padded_k, original_M, original_K,
88+
input_base, output_base, padded_m, padded_k, original_M, original_K,
8989
block_x, block_y, grid_dim_x, grid_dim_y);
9090
} else if (vec_load_size == 2) {
9191
swizzle_col_scaling_kernel_impl<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>(
92-
input_base, output_base, padded_m, padded_k, original_M, original_K,
92+
input_base, output_base, padded_m, padded_k, original_M, original_K,
9393
block_x, block_y, grid_dim_x, grid_dim_y);
9494
} else {
9595
swizzle_col_scaling_kernel_impl<int, SF_TILE_DIM_M, SF_TILE_DIM_K>(
96-
input_base, output_base, padded_m, padded_k, original_M, original_K,
96+
input_base, output_base, padded_m, padded_k, original_M, original_K,
9797
block_x, block_y, grid_dim_x, grid_dim_y);
9898
}
9999
}
@@ -113,34 +113,34 @@
113113
if (blockIdx.x == 0 && threadIdx.x == 0) {
114114
int current_block_offset = 0;
115115
size_t current_scale_offset = 0;
116-
116+
117117
for (size_t i = 0; i < num_tensors; ++i) {
118118
block_offsets[i] = current_block_offset;
119119
scale_offsets[i] = current_scale_offset;
120-
120+
121121
size_t m = rowwise ? m_array[i] : k_array[i];
122122
size_t k = rowwise ? k_array[i] : m_array[i];
123-
123+
124124
size_t padded_m = round_up_to_multiple(m, 128);
125125
size_t padded_k = round_up_to_multiple(DIVUP(k, static_cast<size_t>(MXFP8_BLOCK_SIZE)), 4);
126-
126+
127127
int num_tiles_m = padded_m / 128;
128128
int num_tiles_k = padded_k / 4;
129-
129+
130130
int vec_load_size = (rowwise ? ((num_tiles_k - 1) % 4 + 1) : ((num_tiles_m - 1) % 4 + 1));
131131
if (vec_load_size == 3) vec_load_size = 1;
132-
132+
133133
int blocks_m = num_tiles_m;
134134
int blocks_k = DIVUP(num_tiles_k, TB_DIM * vec_load_size);
135135
if (!rowwise) {
136136
blocks_m = DIVUP(num_tiles_m, vec_load_size);
137137
blocks_k = DIVUP(num_tiles_k, TB_DIM);
138138
}
139-
139+
140140
current_block_offset += blocks_m * blocks_k;
141141
current_scale_offset += padded_m * padded_k * scale_elem_size;
142142
}
143-
143+
144144
block_offsets[num_tensors] = current_block_offset;
145145
scale_offsets[num_tensors] = current_scale_offset;
146146
*total_blocks = current_block_offset;
@@ -150,7 +150,10 @@
150150
151151
namespace transformer_engine {
152152
"""
153-
content = content.replace("namespace transformer_engine {\n\nvoid swizzle_grouped_scaling_factors", kernels_code + "\nvoid swizzle_grouped_scaling_factors")
153+
content = content.replace(
154+
"namespace transformer_engine {\n\nvoid swizzle_grouped_scaling_factors",
155+
kernels_code + "\nvoid swizzle_grouped_scaling_factors",
156+
)
154157

155158
# 2. Modify swizzle_grouped_scaling_factors
156159
old_func = """void swizzle_grouped_scaling_factors(const GroupedTensor* input, GroupedTensor* output,
@@ -206,7 +209,7 @@
206209
auto launch_grouped_swizzle_variable = [&](bool rowwise) {
207210
const size_t scale_elem_size = rowwise ? typeToSize(input->scale_inv.dtype)
208211
: typeToSize(input->columnwise_scale_inv.dtype);
209-
212+
210213
compute_grouped_swizzle_setup<<<1, 1, 0, stream>>>(
211214
m_array, k_array, d_block_offsets, d_scale_offsets, d_total_blocks,
212215
d_global_counter, num_tensors, rowwise, scale_elem_size);
@@ -215,7 +218,7 @@
215218
grouped_swizzle_scaling_variable_shape_kernel<SF_TILE_DIM_M, SF_TILE_DIM_K>,
216219
cudaFuncAttributeMaxDynamicSharedMemorySize, max_slm_size));
217220
218-
int persistent_blocks = 108 * 8;
221+
int persistent_blocks = 108 * 8;
219222
dim3 num_blocks(persistent_blocks);
220223
221224
const void* input_ptr = rowwise ? input->scale_inv.dptr : input->columnwise_scale_inv.dptr;
@@ -257,4 +260,3 @@
257260

258261
with open("transformer_engine/common/swizzle/swizzle.cu", "w") as f:
259262
f.write(content)
260-

patch_swizzle_cpp.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
});"""
1111

1212
new_code = """ swizzle_output.set_with_gemm_swizzled_scales(true);
13-
13+
1414
size_t num_tensors = input.num_tensors();
1515
size_t workspace_size = (num_tensors + 2) * sizeof(int) + (num_tensors + 1) * sizeof(size_t);
1616
workspace_size = roundup(workspace_size, 256);
@@ -40,4 +40,3 @@
4040

4141
with open("transformer_engine/pytorch/csrc/extensions/swizzle.cpp", "w") as f:
4242
f.write(content)
43-

transformer_engine/common/util/vectorized_pointwise.h

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,8 @@ __launch_bounds__(unary_kernel_threads) __global__
228228
loader.load(tid, size);
229229
#pragma unroll
230230
for (int i = 0; i < nvec; ++i) {
231-
const size_t global_idx = (aligned ? (tid * nvec + i) : (tid * nvec + i - loader.alignment()));
231+
const size_t global_idx =
232+
(aligned ? (tid * nvec + i) : (tid * nvec + i - loader.alignment()));
232233
if (global_idx >= size) continue;
233234

234235
ComputeType val = static_cast<ComputeType>(loader.separate()[i]);
@@ -332,7 +333,8 @@ __launch_bounds__(unary_kernel_threads) __global__
332333
grad_loader.load(tid, size);
333334
#pragma unroll
334335
for (int i = 0; i < nvec; ++i) {
335-
const size_t global_idx = (aligned ? (tid * nvec + i) : (tid * nvec + i - loader.alignment()));
336+
const size_t global_idx =
337+
(aligned ? (tid * nvec + i) : (tid * nvec + i - loader.alignment()));
336338
if (global_idx >= size) continue;
337339

338340
ComputeType val = static_cast<ComputeType>(loader.separate()[i]);
@@ -466,19 +468,19 @@ void VectorizedUnaryKernelLauncher(const InputType *input, const fp32 *noop, Out
466468
switch (align) {
467469
case Alignment::SAME_ALIGNED:
468470
unary_kernel<nvec, true, fp32, Param, OP><<<grid, threads, 0, stream>>>(
469-
input, noop, output, scale, amax, scale_inv, params, N, num_aligned_elements,
470-
offsets, first_dims, last_dims, num_tensors, scale_numel, scale_inv_numel, amax_numel);
471+
input, noop, output, scale, amax, scale_inv, params, N, num_aligned_elements, offsets,
472+
first_dims, last_dims, num_tensors, scale_numel, scale_inv_numel, amax_numel);
471473
break;
472474
case Alignment::SAME_UNALIGNED:
473475
unary_kernel<nvec, false, fp32, Param, OP><<<grid, threads, 0, stream>>>(
474-
input, noop, output, scale, amax, scale_inv, params, N, num_aligned_elements,
475-
offsets, first_dims, last_dims, num_tensors, scale_numel, scale_inv_numel, amax_numel);
476+
input, noop, output, scale, amax, scale_inv, params, N, num_aligned_elements, offsets,
477+
first_dims, last_dims, num_tensors, scale_numel, scale_inv_numel, amax_numel);
476478
break;
477479
case Alignment::DIFFERENT: {
478480
// If the pointers are aligned differently we cannot vectorize
479481
unary_kernel<1, true, fp32, Param, OP><<<grid, threads, 0, stream>>>(
480-
input, noop, output, scale, amax, scale_inv, params, N, N,
481-
offsets, first_dims, last_dims, num_tensors, scale_numel, scale_inv_numel, amax_numel);
482+
input, noop, output, scale, amax, scale_inv, params, N, N, offsets, first_dims,
483+
last_dims, num_tensors, scale_numel, scale_inv_numel, amax_numel);
482484
break;
483485
}
484486
}
@@ -508,19 +510,19 @@ void VectorizedUnaryGradKernelLauncher(const InputTypeGrad *grad, const InputTyp
508510
switch (align) {
509511
case Alignment::SAME_ALIGNED:
510512
unary_grad_kernel<nvec, true, fp32, Param, OP><<<grid, threads, 0, stream>>>(
511-
grad, input, output, scale, amax, scale_inv, params, N, num_aligned_elements,
512-
offsets, first_dims, last_dims, num_tensors, scale_numel, scale_inv_numel, amax_numel);
513+
grad, input, output, scale, amax, scale_inv, params, N, num_aligned_elements, offsets,
514+
first_dims, last_dims, num_tensors, scale_numel, scale_inv_numel, amax_numel);
513515
break;
514516
case Alignment::SAME_UNALIGNED:
515517
unary_grad_kernel<nvec, false, fp32, Param, OP><<<grid, threads, 0, stream>>>(
516-
grad, input, output, scale, amax, scale_inv, params, N, num_aligned_elements,
517-
offsets, first_dims, last_dims, num_tensors, scale_numel, scale_inv_numel, amax_numel);
518+
grad, input, output, scale, amax, scale_inv, params, N, num_aligned_elements, offsets,
519+
first_dims, last_dims, num_tensors, scale_numel, scale_inv_numel, amax_numel);
518520
break;
519521
case Alignment::DIFFERENT: {
520522
// If the pointers are aligned differently we cannot vectorize
521523
unary_grad_kernel<1, true, fp32, Param, OP><<<grid, threads, 0, stream>>>(
522-
grad, input, output, scale, amax, scale_inv, params, N, N,
523-
offsets, first_dims, last_dims, num_tensors, scale_numel, scale_inv_numel, amax_numel);
524+
grad, input, output, scale, amax, scale_inv, params, N, N, offsets, first_dims,
525+
last_dims, num_tensors, scale_numel, scale_inv_numel, amax_numel);
524526
break;
525527
}
526528
}

0 commit comments

Comments
 (0)