Skip to content

Commit a7e1b99

Browse files
committed
refactor: fix Matmul nullptr safety and convert GemmParams/SgemvParams to designated initializers
- Save input1_dims_/input2_dims_ in Matmul::SetupContext to avoid Dims() calls on potentially-null saved tensors in Backward
- Get device from grad_output instead of input1 in Matmul::Backward
- Add CHECK guards before dereferencing nullable saved tensors
- Convert all GemmParams/SgemvParams construction in linear.cu, matmul.cu, outer.cu to C++20 designated initializer form
1 parent 252e6cd commit a7e1b99

5 files changed

Lines changed: 233 additions & 234 deletions

File tree

infini_train/include/autograd/matmul.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,5 +23,7 @@ class Matmul : public Function {
2323

2424
private:
2525
int64_t out_features_ = 0;
26+
std::vector<int64_t> input1_dims_;
27+
std::vector<int64_t> input2_dims_;
2628
};
2729
} // namespace infini_train::autograd

infini_train/src/autograd/matmul.cc

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@ void Matmul::SetupContext(const std::vector<std::shared_ptr<Tensor>> &input_tens
4242
};
4343

4444
saved_tensors_ = {need_grad_input2 ? cast(input1) : nullptr, need_grad_input1 ? cast(input2) : nullptr};
45+
input1_dims_ = input1->Dims();
46+
input2_dims_ = input2->Dims();
4547
out_features_ = output->Dims()[0];
4648
}
4749

@@ -56,18 +58,20 @@ std::vector<std::shared_ptr<Tensor>> Matmul::Backward(const std::vector<std::sha
5658
bool need_grad_input1 = needs_input_grad_.size() > 0 && needs_input_grad_[0];
5759
bool need_grad_input2 = needs_input_grad_.size() > 1 && needs_input_grad_[1];
5860

59-
auto device = input1->GetDevice().type();
61+
auto device = grad_output->GetDevice().type();
6062

6163
std::shared_ptr<Tensor> grad_input = nullptr;
6264
std::shared_ptr<Tensor> grad_other = nullptr;
6365

6466
if (need_grad_input1) {
67+
CHECK(input2 != nullptr) << "input2 not saved but need_grad_input1 is true";
6568
grad_input = Dispatcher::Instance().Call<std::shared_ptr<Tensor>>({device, "MatmulBackwardInput"}, input2,
66-
grad_output, input1->Dims());
69+
grad_output, input1_dims_);
6770
}
6871
if (need_grad_input2) {
72+
CHECK(input1 != nullptr) << "input1 not saved but need_grad_input2 is true";
6973
grad_other = Dispatcher::Instance().Call<std::shared_ptr<Tensor>>({device, "MatmulBackwardOther"}, input1,
70-
grad_output, input2->Dims());
74+
grad_output, input2_dims_);
7175
}
7276

7377
return {grad_input, grad_other};

infini_train/src/kernels/cuda/linear.cu

Lines changed: 81 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -80,18 +80,18 @@ std::shared_ptr<Tensor> LinearForward(const std::shared_ptr<Tensor> &input, cons
8080
// When bs==1 and fp32, use cublasSgemv (more efficient than GEMM for matrix-vector).
8181
// cublasSgemv does not support bf16, so bf16 falls through to GemmCuda.
8282
if (bs == 1 && dtype == DataType::kFLOAT32) {
83-
SgemvParams p;
84-
p.trans = transpose ? CUBLAS_OP_T : CUBLAS_OP_N;
85-
p.m = static_cast<int>(transpose ? in_features : out_features);
86-
p.n = static_cast<int>(transpose ? out_features : in_features);
87-
p.A = static_cast<const float *>(weight->DataPtr());
88-
p.lda = static_cast<int>(transpose ? in_features : out_features);
89-
p.x = static_cast<const float *>(input->DataPtr());
90-
p.y = static_cast<float *>(output->DataPtr());
91-
p.alpha = 1.0f;
92-
p.beta = 1.0f; // output already initialized with bias or zero above
93-
p.blas_handle = GetCublasHandle(device);
94-
SgemvCuda(p);
83+
SgemvCuda(SgemvParams{
84+
.trans = transpose ? CUBLAS_OP_T : CUBLAS_OP_N,
85+
.m = static_cast<int>(transpose ? in_features : out_features),
86+
.n = static_cast<int>(transpose ? out_features : in_features),
87+
.A = static_cast<const float *>(weight->DataPtr()),
88+
.lda = static_cast<int>(transpose ? in_features : out_features),
89+
.x = static_cast<const float *>(input->DataPtr()),
90+
.y = static_cast<float *>(output->DataPtr()),
91+
.alpha = 1.0f,
92+
.beta = 1.0f, // output already initialized with bias or zero above
93+
.blas_handle = GetCublasHandle(device),
94+
});
9595
} else {
9696
// cuBLAS is column-major
9797
// - if a is transposed:
@@ -106,26 +106,25 @@ std::shared_ptr<Tensor> LinearForward(const std::shared_ptr<Tensor> &input, cons
106106
// C = output.T[out_features, bs]
107107
// A = weight.T[out_features, in_features]
108108
// B = input.T[in_features, bs]
109-
GemmParams p;
110-
p.trans_a = transpose ? CUBLAS_OP_T : CUBLAS_OP_N;
111-
p.trans_b = CUBLAS_OP_N;
112-
p.m = static_cast<int>(out_features);
113-
p.n = static_cast<int>(bs);
114-
p.k = static_cast<int>(in_features);
115-
p.A = weight->DataPtr();
116-
p.lda = static_cast<int>(transpose ? in_features : out_features);
117-
p.B = input->DataPtr();
118-
p.ldb = static_cast<int>(in_features);
119-
p.C = output->DataPtr();
120-
p.ldc = static_cast<int>(out_features);
121-
p.alpha = 1.0f;
122-
p.beta = 1.0f; // bias already written into output; beta=1 accumulates
123-
p.batch_count = 1;
124-
p.input_dtype = dtype;
125-
p.output_dtype = dtype;
126-
p.blas_handle = GetCublasHandle(device);
127-
128-
GemmCuda(p);
109+
GemmCuda(GemmParams{
110+
.trans_a = transpose ? CUBLAS_OP_T : CUBLAS_OP_N,
111+
.trans_b = CUBLAS_OP_N,
112+
.m = static_cast<int>(out_features),
113+
.n = static_cast<int>(bs),
114+
.k = static_cast<int>(in_features),
115+
.A = weight->DataPtr(),
116+
.lda = static_cast<int>(transpose ? in_features : out_features),
117+
.B = input->DataPtr(),
118+
.ldb = static_cast<int>(in_features),
119+
.C = output->DataPtr(),
120+
.ldc = static_cast<int>(out_features),
121+
.alpha = 1.0f,
122+
.beta = 1.0f, // bias already written into output; beta=1 accumulates
123+
.batch_count = 1,
124+
.input_dtype = dtype,
125+
.output_dtype = dtype,
126+
.blas_handle = GetCublasHandle(device),
127+
});
129128
}
130129

131130
return output;
@@ -172,18 +171,18 @@ std::shared_ptr<Tensor> LinearBackwardInput(const std::shared_ptr<Tensor> &weigh
172171
// When bs==1 and fp32, use cublasSgemv (more efficient than GEMM for matrix-vector).
173172
// cublasSgemv does not support bf16, so bf16 falls through to GemmCuda.
174173
if (bs == 1 && compute_dtype == DataType::kFLOAT32) {
175-
SgemvParams p;
176-
p.trans = transpose ? CUBLAS_OP_N : CUBLAS_OP_T;
177-
p.m = static_cast<int>(transpose ? in_features : out_features);
178-
p.n = static_cast<int>(transpose ? out_features : in_features);
179-
p.A = static_cast<const float *>(weight->DataPtr());
180-
p.lda = static_cast<int>(transpose ? in_features : out_features);
181-
p.x = static_cast<const float *>(grad_output_promoted->DataPtr());
182-
p.y = static_cast<float *>(grad_input->DataPtr());
183-
p.alpha = 1.0f;
184-
p.beta = 0.0f;
185-
p.blas_handle = GetCublasHandle(grad_output->GetDevice());
186-
SgemvCuda(p);
174+
SgemvCuda(SgemvParams{
175+
.trans = transpose ? CUBLAS_OP_N : CUBLAS_OP_T,
176+
.m = static_cast<int>(transpose ? in_features : out_features),
177+
.n = static_cast<int>(transpose ? out_features : in_features),
178+
.A = static_cast<const float *>(weight->DataPtr()),
179+
.lda = static_cast<int>(transpose ? in_features : out_features),
180+
.x = static_cast<const float *>(grad_output_promoted->DataPtr()),
181+
.y = static_cast<float *>(grad_input->DataPtr()),
182+
.alpha = 1.0f,
183+
.beta = 0.0f,
184+
.blas_handle = GetCublasHandle(grad_output->GetDevice()),
185+
});
187186
} else {
188187
// - if transpose:
189188
// weight is [out_features, in_features] here
@@ -198,26 +197,25 @@ std::shared_ptr<Tensor> LinearBackwardInput(const std::shared_ptr<Tensor> &weigh
198197
// C = d_input.T[in_features, bs]
199198
// A = weight.T[out_features, in_features]
200199
// B = d_output.T[out_features, bs]
201-
GemmParams p;
202-
p.trans_a = transpose ? CUBLAS_OP_N : CUBLAS_OP_T;
203-
p.trans_b = CUBLAS_OP_N;
204-
p.m = static_cast<int>(in_features);
205-
p.n = static_cast<int>(bs);
206-
p.k = static_cast<int>(out_features);
207-
p.A = weight->DataPtr();
208-
p.lda = static_cast<int>(transpose ? in_features : out_features);
209-
p.B = grad_output_promoted->DataPtr();
210-
p.ldb = static_cast<int>(out_features);
211-
p.C = grad_input->DataPtr();
212-
p.ldc = static_cast<int>(in_features);
213-
p.alpha = 1.0f;
214-
p.beta = 0.0f;
215-
p.batch_count = 1;
216-
p.input_dtype = compute_dtype;
217-
p.output_dtype = output_dtype;
218-
p.blas_handle = GetCublasHandle(grad_output->GetDevice());
219-
220-
GemmCuda(p);
200+
GemmCuda(GemmParams{
201+
.trans_a = transpose ? CUBLAS_OP_N : CUBLAS_OP_T,
202+
.trans_b = CUBLAS_OP_N,
203+
.m = static_cast<int>(in_features),
204+
.n = static_cast<int>(bs),
205+
.k = static_cast<int>(out_features),
206+
.A = weight->DataPtr(),
207+
.lda = static_cast<int>(transpose ? in_features : out_features),
208+
.B = grad_output_promoted->DataPtr(),
209+
.ldb = static_cast<int>(out_features),
210+
.C = grad_input->DataPtr(),
211+
.ldc = static_cast<int>(in_features),
212+
.alpha = 1.0f,
213+
.beta = 0.0f,
214+
.batch_count = 1,
215+
.input_dtype = compute_dtype,
216+
.output_dtype = output_dtype,
217+
.blas_handle = GetCublasHandle(grad_output->GetDevice()),
218+
});
221219
}
222220

223221
return grad_input;
@@ -259,26 +257,25 @@ std::shared_ptr<Tensor> LinearBackwardWeight(const std::shared_ptr<Tensor> &inpu
259257
const int lda = static_cast<int>(transpose ? in_features : out_features);
260258
const int ldb = static_cast<int>(transpose ? out_features : in_features);
261259

262-
GemmParams p;
263-
p.trans_a = CUBLAS_OP_N;
264-
p.trans_b = CUBLAS_OP_T;
265-
p.m = static_cast<int>(transpose ? in_features : out_features);
266-
p.n = static_cast<int>(transpose ? out_features : in_features);
267-
p.k = static_cast<int>(bs);
268-
p.A = a;
269-
p.lda = lda;
270-
p.B = b;
271-
p.ldb = ldb;
272-
p.C = grad_weight->DataPtr();
273-
p.ldc = static_cast<int>(transpose ? in_features : out_features);
274-
p.alpha = 1.0f;
275-
p.beta = 0.0f;
276-
p.batch_count = 1;
277-
p.input_dtype = compute_dtype;
278-
p.output_dtype = output_dtype;
279-
p.blas_handle = GetCublasHandle(grad_output->GetDevice());
280-
281-
GemmCuda(p);
260+
GemmCuda(GemmParams{
261+
.trans_a = CUBLAS_OP_N,
262+
.trans_b = CUBLAS_OP_T,
263+
.m = static_cast<int>(transpose ? in_features : out_features),
264+
.n = static_cast<int>(transpose ? out_features : in_features),
265+
.k = static_cast<int>(bs),
266+
.A = a,
267+
.lda = lda,
268+
.B = b,
269+
.ldb = ldb,
270+
.C = grad_weight->DataPtr(),
271+
.ldc = static_cast<int>(transpose ? in_features : out_features),
272+
.alpha = 1.0f,
273+
.beta = 0.0f,
274+
.batch_count = 1,
275+
.input_dtype = compute_dtype,
276+
.output_dtype = output_dtype,
277+
.blas_handle = GetCublasHandle(grad_output->GetDevice()),
278+
});
282279

283280
return grad_weight;
284281
}

infini_train/src/kernels/cuda/matmul.cu

Lines changed: 66 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -47,29 +47,28 @@ std::shared_ptr<Tensor> MatmulForward(const std::shared_ptr<Tensor> &input, cons
4747
// A = other.T[*, n, k]
4848
// B = input.T[*, k, m]
4949
// NOTE(zbl): the last cublasGemmAlgo_t param has no effect on GPU arch >= sm_80(Ampere)
50-
GemmParams p;
51-
p.trans_a = CUBLAS_OP_N;
52-
p.trans_b = CUBLAS_OP_N;
53-
p.m = static_cast<int>(n);
54-
p.n = static_cast<int>(m);
55-
p.k = static_cast<int>(k);
56-
p.A = other->DataPtr();
57-
p.lda = static_cast<int>(n);
58-
p.stride_a = n * k;
59-
p.B = input->DataPtr();
60-
p.ldb = static_cast<int>(k);
61-
p.stride_b = k * m;
62-
p.C = output->DataPtr();
63-
p.ldc = static_cast<int>(n);
64-
p.stride_c = m * n;
65-
p.alpha = 1.0f;
66-
p.beta = 0.0f;
67-
p.batch_count = static_cast<int>(bs);
68-
p.input_dtype = dtype;
69-
p.output_dtype = dtype;
70-
p.blas_handle = GetCublasHandle(device);
71-
72-
GemmCuda(p);
50+
GemmCuda(GemmParams{
51+
.trans_a = CUBLAS_OP_N,
52+
.trans_b = CUBLAS_OP_N,
53+
.m = static_cast<int>(n),
54+
.n = static_cast<int>(m),
55+
.k = static_cast<int>(k),
56+
.A = other->DataPtr(),
57+
.lda = static_cast<int>(n),
58+
.B = input->DataPtr(),
59+
.ldb = static_cast<int>(k),
60+
.C = output->DataPtr(),
61+
.ldc = static_cast<int>(n),
62+
.alpha = 1.0f,
63+
.beta = 0.0f,
64+
.batch_count = static_cast<int>(bs),
65+
.stride_a = n * k,
66+
.stride_b = k * m,
67+
.stride_c = m * n,
68+
.input_dtype = dtype,
69+
.output_dtype = dtype,
70+
.blas_handle = GetCublasHandle(device),
71+
});
7372

7473
return output;
7574
}
@@ -119,29 +118,28 @@ std::shared_ptr<Tensor> MatmulBackwardInput(const std::shared_ptr<Tensor> &other
119118
// C = grad_input.T[*, k, m]
120119
// A = other.T[*, n, k]
121120
// B = grad_output.T[*, n, m]
122-
GemmParams p;
123-
p.trans_a = CUBLAS_OP_T;
124-
p.trans_b = CUBLAS_OP_N;
125-
p.m = static_cast<int>(k);
126-
p.n = static_cast<int>(m);
127-
p.k = static_cast<int>(n);
128-
p.A = other->DataPtr();
129-
p.lda = static_cast<int>(n);
130-
p.stride_a = k * n;
131-
p.B = grad_output_promoted->DataPtr();
132-
p.ldb = static_cast<int>(n);
133-
p.stride_b = n * m;
134-
p.C = grad_input->DataPtr();
135-
p.ldc = static_cast<int>(k);
136-
p.stride_c = m * k;
137-
p.alpha = 1.0f;
138-
p.beta = 0.0f;
139-
p.batch_count = static_cast<int>(bs);
140-
p.input_dtype = compute_dtype;
141-
p.output_dtype = output_dtype;
142-
p.blas_handle = GetCublasHandle(device);
143-
144-
GemmCuda(p);
121+
GemmCuda(GemmParams{
122+
.trans_a = CUBLAS_OP_T,
123+
.trans_b = CUBLAS_OP_N,
124+
.m = static_cast<int>(k),
125+
.n = static_cast<int>(m),
126+
.k = static_cast<int>(n),
127+
.A = other->DataPtr(),
128+
.lda = static_cast<int>(n),
129+
.B = grad_output_promoted->DataPtr(),
130+
.ldb = static_cast<int>(n),
131+
.C = grad_input->DataPtr(),
132+
.ldc = static_cast<int>(k),
133+
.alpha = 1.0f,
134+
.beta = 0.0f,
135+
.batch_count = static_cast<int>(bs),
136+
.stride_a = k * n,
137+
.stride_b = n * m,
138+
.stride_c = m * k,
139+
.input_dtype = compute_dtype,
140+
.output_dtype = output_dtype,
141+
.blas_handle = GetCublasHandle(device),
142+
});
145143

146144
return grad_input;
147145
}
@@ -190,29 +188,28 @@ std::shared_ptr<Tensor> MatmulBackwardOther(const std::shared_ptr<Tensor> &input
190188
// C = grad_other.T[*, n, k]
191189
// A = grad_output.T[*, n, m]
192190
// B = input.T[*, k, m]
193-
GemmParams p;
194-
p.trans_a = CUBLAS_OP_N;
195-
p.trans_b = CUBLAS_OP_T;
196-
p.m = static_cast<int>(n);
197-
p.n = static_cast<int>(k);
198-
p.k = static_cast<int>(m);
199-
p.A = grad_output_promoted->DataPtr();
200-
p.lda = static_cast<int>(n);
201-
p.stride_a = n * m;
202-
p.B = input1->DataPtr();
203-
p.ldb = static_cast<int>(k);
204-
p.stride_b = k * m;
205-
p.C = grad_other->DataPtr();
206-
p.ldc = static_cast<int>(n);
207-
p.stride_c = n * k;
208-
p.alpha = 1.0f;
209-
p.beta = 0.0f;
210-
p.batch_count = static_cast<int>(bs);
211-
p.input_dtype = compute_dtype;
212-
p.output_dtype = output_dtype;
213-
p.blas_handle = GetCublasHandle(device);
214-
215-
GemmCuda(p);
191+
GemmCuda(GemmParams{
192+
.trans_a = CUBLAS_OP_N,
193+
.trans_b = CUBLAS_OP_T,
194+
.m = static_cast<int>(n),
195+
.n = static_cast<int>(k),
196+
.k = static_cast<int>(m),
197+
.A = grad_output_promoted->DataPtr(),
198+
.lda = static_cast<int>(n),
199+
.B = input1->DataPtr(),
200+
.ldb = static_cast<int>(k),
201+
.C = grad_other->DataPtr(),
202+
.ldc = static_cast<int>(n),
203+
.alpha = 1.0f,
204+
.beta = 0.0f,
205+
.batch_count = static_cast<int>(bs),
206+
.stride_a = n * m,
207+
.stride_b = k * m,
208+
.stride_c = n * k,
209+
.input_dtype = compute_dtype,
210+
.output_dtype = output_dtype,
211+
.blas_handle = GetCublasHandle(device),
212+
});
216213

217214
return grad_other;
218215
}

0 commit comments

Comments
 (0)