Skip to content

Commit 80ab846

Browse files
committed
refactor(autocast): move autocast from dispatcher to autograd boundary
Apply autocast in Function::Apply before Forward/SetupContext so saved tensors are stored in the actual forward compute dtype, instead of guessing it from output->Dtype() inside SetupContext. Drops the duplicate cast workaround in Matmul/Linear::SetupContext. Add AutocastByName entry on AutocastContext (keyed by op name string) and extend GetBaseOpName to strip the "Function" suffix used by autograd::Function::type_. Remove the now-redundant autocast hook from Dispatcher::Call; backward kernels and internal helpers are no longer accidentally re-cast. Add direct common.h includes to elementwise/gather kernels that previously relied on the dispatcher.h -> autocast.h -> common.h transitive include chain.
1 parent 49cb34b commit 80ab846

8 files changed

Lines changed: 22 additions & 63 deletions

File tree

infini_train/include/autocast.h

Lines changed: 7 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -11,25 +11,10 @@
1111
namespace infini_train {
1212
namespace {
1313
inline std::string_view GetBaseOpName(std::string_view op) {
14-
constexpr std::string_view forward_suffix = "Forward";
15-
constexpr std::string_view backward_suffix = "Backward";
16-
17-
// Check for "Forward" suffix
18-
if (op.size() >= forward_suffix.size()) {
19-
const auto suffix_pos = op.size() - forward_suffix.size();
20-
if (op.substr(suffix_pos) == forward_suffix) {
21-
return op.substr(0, suffix_pos);
22-
}
23-
}
24-
25-
// Check for "Backward" suffix
26-
if (op.size() >= backward_suffix.size()) {
27-
const auto suffix_pos = op.size() - backward_suffix.size();
28-
if (op.substr(suffix_pos) == backward_suffix) {
29-
return op.substr(0, suffix_pos);
30-
}
14+
constexpr std::string_view function_suffix = "Function";
15+
if (op.size() >= function_suffix.size() && op.substr(op.size() - function_suffix.size()) == function_suffix) {
16+
return op.substr(0, op.size() - function_suffix.size());
3117
}
32-
3318
return op;
3419
}
3520
}; // namespace
@@ -97,18 +82,14 @@ struct AutocastContext {
9782
Device::DeviceType device_type = Device::DeviceType::kCPU; // Target device type (CPU/GPU)
9883
DataType autocast_dtype = DataType::kBFLOAT16; // The data type used for autocasting
9984

100-
template <typename... ArgsT> void Autocast(std::pair<Device::DeviceType, std::string> key, ArgsT &...args) {
85+
// Cast a parameter pack of tensors (or shared_ptr<Tensor>) according to the cast policy
86+
// associated with `op_name`. Called from autograd::Function::Apply with type_ as op_name.
87+
template <typename... ArgsT> void Autocast(std::string_view op_name, ArgsT &...args) {
10188
if (!enabled) {
10289
return;
10390
}
10491

105-
if (device_type != key.first) {
106-
LOG_LOC(FATAL, "In AutocastContext::Autocast(): the AutocastContext device_type is different from the one "
107-
"passed in. Don't know what to do.");
108-
return;
109-
}
110-
111-
auto map_it = kOpCastPolicyMap.find(GetBaseOpName(key.second));
92+
auto map_it = kOpCastPolicyMap.find(GetBaseOpName(op_name));
11293
if (map_it == kOpCastPolicyMap.end()) {
11394
return;
11495
}
@@ -132,7 +113,6 @@ struct AutocastContext {
132113
}
133114
};
134115

135-
// Process each argument
136116
auto cast_arg = [&](auto &arg) {
137117
using T = std::decay_t<decltype(arg)>;
138118
if constexpr (std::is_same_v<T, std::shared_ptr<Tensor>>) {
@@ -156,7 +136,6 @@ struct AutocastContext {
156136
}
157137
};
158138

159-
// Apply casting to each argument
160139
(cast_arg(args), ...);
161140
}
162141
};

infini_train/include/dispatcher.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66

77
#include "glog/logging.h"
88

9-
#include "infini_train/include/autocast.h"
109
#include "infini_train/include/device.h"
1110
#ifdef PROFILE_MODE
1211
#include "infini_train/include/profiler.h"
@@ -73,7 +72,6 @@ class Dispatcher {
7372

7473
template <typename RetT, class... ArgsT> RetT Call(KeyT key, ArgsT... args) const {
7574
auto kernel = this->GetKernel(key);
76-
tls_autocast_context.Autocast(key, args...);
7775
#ifdef PROFILE_MODE
7876
SetProfileContext(key.second, key.first);
7977
#endif

infini_train/src/autograd/function.cc

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
#include "glog/logging.h"
44

5+
#include "infini_train/include/autocast.h"
56
#include "infini_train/include/autograd/accumulate.h"
67
#include "infini_train/include/autograd/function_hook.h"
78
#include "infini_train/include/autograd/grad_mode.h"
@@ -46,12 +47,19 @@ std::vector<std::shared_ptr<Tensor>> Function::Apply(const std::vector<std::shar
4647
}
4748
}
4849

50+
// Apply autocast once at the autograd boundary so Forward / SetupContext receive
51+
// tensors already in the compute dtype. The shared_ptr copies are local; we keep
52+
// the caller's `input_tensors` untouched so next_functions_ wires up to the
53+
// original autograd graph (leaf -> AccumulateGrad / non-leaf -> grad_fn).
54+
auto compute_inputs = input_tensors;
55+
for (auto &t : compute_inputs) { tls_autocast_context.Autocast(type_, t); }
56+
4957
std::vector<std::shared_ptr<Tensor>> output_tensors;
5058
{
5159
autograd::NoGradGuard no_grad;
5260
// no_grad in autograd.Function.Forward()
53-
output_tensors = Forward(input_tensors);
54-
SetupContext(input_tensors, output_tensors);
61+
output_tensors = Forward(compute_inputs);
62+
SetupContext(compute_inputs, output_tensors);
5563
}
5664

5765
// Call forward post-hooks

infini_train/src/autograd/linear.cc

Lines changed: 1 addition & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -20,26 +20,12 @@ void Linear::SetupContext(const std::vector<std::shared_ptr<Tensor>> &input_tens
2020
const std::vector<std::shared_ptr<Tensor>> &output_tensors) {
2121
const auto &input = input_tensors[0];
2222
const auto &weight = input_tensors[1];
23-
// Cast saved tensors to forward compute dtype (output dtype) so backward
24-
// computes in the same precision as forward, matching PyTorch's behavior.
2523

26-
// FIXME: An extra cast (input/weight -> compute_dtype) is performed here because
27-
// autocast runs before autograd. The correct approach is to adjust the ordering or
28-
// integration of autocast and autograd so that autograd receives already-cast tensors,
29-
// avoiding the redundant cast.
30-
31-
// FIXME: compute_dtype is not necessarily the dtype of output_tensor; it should be
32-
// determined by autocast, not derived from output_tensors[0]->Dtype().
33-
auto compute_dtype = output_tensors[0]->Dtype();
3424
bool need_input = needs_input_grad_.size() > 0 && needs_input_grad_[0];
3525
bool need_weight = needs_input_grad_.size() > 1 && needs_input_grad_[1];
3626

37-
auto cast = [&](const std::shared_ptr<Tensor> &t) {
38-
return t->Dtype() == compute_dtype ? t : std::make_shared<Tensor>(t->To(compute_dtype));
39-
};
40-
4127
// grad_input needs weight, grad_weight needs input
42-
saved_tensors_ = {need_weight ? cast(input) : nullptr, need_input ? cast(weight) : nullptr};
28+
saved_tensors_ = {need_weight ? input : nullptr, need_input ? weight : nullptr};
4329

4430
transpose_ = true;
4531
bias_ = input_tensors.size() == 3;

infini_train/src/autograd/matmul.cc

Lines changed: 1 addition & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -20,28 +20,13 @@ void Matmul::SetupContext(const std::vector<std::shared_ptr<Tensor>> &input_tens
2020
const auto &input1 = input_tensors[0];
2121
const auto &input2 = input_tensors[1];
2222
const auto &output = output_tensors[0];
23-
// Cast saved tensors to forward compute dtype (output dtype) so backward
24-
// computes in the same precision as forward, matching PyTorch's behavior.
25-
26-
// FIXME: An extra cast (input1/input2 -> compute_dtype) is performed here because
27-
// autocast runs before autograd. The correct approach is to adjust the ordering or
28-
// integration of autocast and autograd so that autograd receives already-cast tensors,
29-
// avoiding the redundant cast.
30-
31-
// FIXME: compute_dtype is not necessarily the dtype of output_tensor; it should be
32-
// determined by autocast, not derived from output->Dtype().
33-
auto compute_dtype = output->Dtype();
3423

3524
// grad_input1 = grad_output @ input2^T, so input2 is needed
3625
// grad_input2 = grad_output^T @ input1, so input1 is needed
3726
bool need_grad_input1 = needs_input_grad_.size() > 0 && needs_input_grad_[0];
3827
bool need_grad_input2 = needs_input_grad_.size() > 1 && needs_input_grad_[1];
3928

40-
auto cast = [&](const std::shared_ptr<Tensor> &t) {
41-
return t->Dtype() == compute_dtype ? t : std::make_shared<Tensor>(t->To(compute_dtype));
42-
};
43-
44-
saved_tensors_ = {need_grad_input2 ? cast(input1) : nullptr, need_grad_input1 ? cast(input2) : nullptr};
29+
saved_tensors_ = {need_grad_input2 ? input1 : nullptr, need_grad_input1 ? input2 : nullptr};
4530
input1_dims_ = input1->Dims();
4631
input2_dims_ = input2->Dims();
4732
out_features_ = output->Dims()[0];

infini_train/src/kernels/cpu/elementwise.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
#include "glog/logging.h"
88

9+
#include "infini_train/include/common/common.h"
910
#include "infini_train/include/device.h"
1011
#include "infini_train/include/dispatcher.h"
1112
#include "infini_train/include/tensor.h"

infini_train/src/kernels/cuda/elementwise.cu

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
#include <cub/warp/warp_reduce.cuh>
44

5+
#include "infini_train/include/common/common.h"
56
#include "infini_train/include/common/cuda/common_cuda.h"
67
#include "infini_train/include/common/cuda/kernel_helper.cuh"
78
#include "infini_train/include/core/runtime/device_guard.h"

infini_train/src/kernels/cuda/gather.cu

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#include "glog/logging.h"
22

3+
#include "infini_train/include/common/common.h"
34
#include "infini_train/include/common/cuda/common_cuda.h"
45
#include "infini_train/include/core/runtime/device_guard.h"
56
#include "infini_train/include/dispatcher.h"

0 commit comments

Comments
 (0)