Skip to content

Commit d2ee257

Browse files
chen2021673kilinchange
authored andcommitted
fix: correct batched matmul strides for bs=1 and integrate CTest into build pipeline
Set stride to 0 when batch_size is 1 to enable proper broadcasting in cuBLAS, and add configurable CTest execution after builds with googletest submodule.
1 parent ea99c0c commit d2ee257

5 files changed

Lines changed: 20 additions & 12 deletions

File tree

docs/test_usage_guide.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
```bash
1010
mkdir build && cd build
11-
cmake -DBUILD_TEST=ON -DUSE_CUDA=ON ..
11+
cmake -DBUILD_TEST=ON -DUSE_CUDA=ON -DUSE_NCCL=ON ..
1212
make -j$(nproc)
1313
```
1414

infini_train/src/kernels/cuda/matmul.cu

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -63,9 +63,9 @@ std::shared_ptr<Tensor> MatmulForward(const std::shared_ptr<Tensor> &input, cons
6363
.alpha = 1.0f,
6464
.beta = 0.0f,
6565
.batch_count = static_cast<int>(bs),
66-
.stride_a = n * k,
67-
.stride_b = k * m,
68-
.stride_c = m * n,
66+
.stride_a = bs > 1 ? n * k : 0,
67+
.stride_b = bs > 1 ? k * m : 0,
68+
.stride_c = bs > 1 ? m * n : 0,
6969
.input_dtype = dtype,
7070
.output_dtype = dtype,
7171
});
@@ -133,9 +133,9 @@ std::shared_ptr<Tensor> MatmulBackwardInput(const std::shared_ptr<Tensor> &other
133133
.alpha = 1.0f,
134134
.beta = 0.0f,
135135
.batch_count = static_cast<int>(bs),
136-
.stride_a = k * n,
137-
.stride_b = n * m,
138-
.stride_c = m * k,
136+
.stride_a = bs > 1 ? k * n : 0,
137+
.stride_b = bs > 1 ? n * m : 0,
138+
.stride_c = bs > 1 ? m * k : 0,
139139
.input_dtype = compute_dtype,
140140
.output_dtype = output_dtype,
141141
});
@@ -202,9 +202,9 @@ std::shared_ptr<Tensor> MatmulBackwardOther(const std::shared_ptr<Tensor> &input
202202
.alpha = 1.0f,
203203
.beta = 0.0f,
204204
.batch_count = static_cast<int>(bs),
205-
.stride_a = n * m,
206-
.stride_b = k * m,
207-
.stride_c = n * k,
205+
.stride_a = bs > 1 ? n * m : 0,
206+
.stride_b = bs > 1 ? k * m : 0,
207+
.stride_c = bs > 1 ? n * k : 0,
208208
.input_dtype = compute_dtype,
209209
.output_dtype = output_dtype,
210210
});

scripts/run_models_and_profile.bash

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,8 @@ BUILD_DIR="$(read_var BUILD_DIR)"; : "${BUILD_DIR:=../build}"
7070
LOG_DIR="$(read_var LOG_DIR)"; : "${LOG_DIR:=logs}"
7171
PROFILE_LOG_DIR="$(read_var PROFILE_LOG_DIR)"; : "${PROFILE_LOG_DIR:=./profile_logs}"
7272
COMPARE_LOG_DIR="$(read_var COMPARE_LOG_DIR)"; : "${COMPARE_LOG_DIR:=}"
73+
RUN_CTEST="$(read_var RUN_CTEST)"; : "${RUN_CTEST:=true}"
74+
CTEST_CMD="$(read_var CTEST_CMD)"; : "${CTEST_CMD:=ctest --output-on-failure -LE cuda -j$(nproc) && ctest --output-on-failure -L cuda -j1}"
7375

7476
mkdir -p "$BUILD_DIR" "$LOG_DIR" "$PROFILE_LOG_DIR"
7577

@@ -244,6 +246,9 @@ for ((id=0; id<num_builds; ++id)); do
244246
# always clean before another build
245247
clean_build_dir
246248
run_and_log "$LAST_CMAKE_CMD" "${build_id}" "no" "build"
249+
if [[ "$RUN_CTEST" == "true" && "$build_profile" != "true" ]]; then
250+
run_and_log "$CTEST_CMD" "ctest_${build_id}" "no" "ctest"
251+
fi
247252

248253
# profile flag for runs
249254
profile_flag="no"

scripts/test_config.json

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,15 @@
77
"LLAMA3_LLMC_FILEPATH": "/data/shared/InfiniTrain-dev/data/llmc/llama3/llama3.2_1B_fp32.bin",
88
"PROFILE_LOG_DIR": "./profile_logs",
99
"LOG_DIR": "./logs",
10-
"COMPARE_LOG_DIR": ""
10+
"COMPARE_LOG_DIR": "",
11+
"RUN_CTEST": "true",
12+
"CTEST_CMD": "ctest --output-on-failure -LE cuda -j$(nproc) && ctest --output-on-failure -L cuda -j1"
1113
},
1214
"builds": [
1315
{
1416
"id": "build_1",
1517
"profile": false,
16-
"cmd": "cmake -DUSE_CUDA=ON -DUSE_NCCL=ON .. && make -j"
18+
"cmd": "cmake -DBUILD_TEST=ON -DUSE_CUDA=ON -DUSE_NCCL=ON .. && make -j"
1719
},
1820
{
1921
"id": "build_2",

third_party/googletest

Submodule googletest added at f8d7d77

0 commit comments

Comments
 (0)