Skip to content

Commit 6e66d46

Browse files
feat: init ac
1 parent b594867 commit 6e66d46

36 files changed

Lines changed: 1188 additions & 121 deletions

example/gpt2/checkpoint_loader.cc

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,11 +52,22 @@ std::tuple<int32_t, infini_train::DataType> DetermineAndCheckVersion(const std::
5252
return {}; // Unreachable, but keeps compiler happy
5353
}
5454
}
55+
56+
// Copies the activation-recompute settings from `runtime_config` onto `*config`.
// Only recompute_granularity, recompute_method and recompute_num_layers are written;
// every other field of `*config` is left as it was.
//
// config:         destination; dereferenced unconditionally, so it must not be null.
// runtime_config: source of the three recompute fields.
void ApplyRuntimeRecomputeConfig(nn::TransformerConfig *config, const nn::TransformerConfig &runtime_config) {
    nn::TransformerConfig &target = *config;
    target.recompute_granularity = runtime_config.recompute_granularity;
    target.recompute_method = runtime_config.recompute_method;
    target.recompute_num_layers = runtime_config.recompute_num_layers;
}
5561
} // namespace
5662

5763
namespace gpt2 {
5864

5965
std::shared_ptr<nn::TransformerModel> LoadFromLLMC(const std::string &filepath) {
66+
return LoadFromLLMC(filepath, gpt2::GPT2Config());
67+
}
68+
69+
std::shared_ptr<nn::TransformerModel> LoadFromLLMC(const std::string &filepath,
70+
const nn::TransformerConfig &runtime_config) {
6071
if (!std::filesystem::exists(filepath)) {
6172
LOG(FATAL) << "File not found: " << filepath;
6273
}
@@ -87,6 +98,7 @@ std::shared_ptr<nn::TransformerModel> LoadFromLLMC(const std::string &filepath)
8798
gpt2_config.n_layer = n_layer;
8899
gpt2_config.n_head = n_head;
89100
gpt2_config.n_embd = n_embd;
101+
ApplyRuntimeRecomputeConfig(&gpt2_config, runtime_config);
90102
auto local_gpt2 = std::make_shared<nn::TransformerModel>(gpt2_config);
91103

92104
LOG(INFO) << "magic: " << magic << " version: " << version << " block_size: " << block_size

example/gpt2/checkpoint_loader.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,14 @@
33
#include <memory>
44
#include <string>
55

6+
#include "infini_train/include/nn/modules/transformer/transformer_config.h"
7+
68
namespace infini_train::nn {
79
class TransformerModel;
810
} // namespace infini_train::nn
911

1012
namespace gpt2 {
1113
// Loads a model from the LLMC checkpoint at `filepath`. The definition forwards to the
// two-argument overload, passing gpt2::GPT2Config() as `runtime_config`.
std::shared_ptr<infini_train::nn::TransformerModel> LoadFromLLMC(const std::string &filepath);
14+
// Loads a model from the LLMC checkpoint at `filepath`; logs FATAL if the file does not exist.
// The recompute_granularity / recompute_method / recompute_num_layers fields of
// `runtime_config` are copied onto the model config before the TransformerModel is constructed.
std::shared_ptr<infini_train::nn::TransformerModel>
15+
LoadFromLLMC(const std::string &filepath, const infini_train::nn::TransformerConfig &runtime_config);
1216
} // namespace gpt2

example/gpt2/main.cc

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,12 @@ DEFINE_bool(sequence_parallel, false, "Whether to enable Sequence Parallel");
7575
DEFINE_uint32(pipeline_parallel, 1, "Pipeline Parallel world size, specified the number of PP stages.");
7676
DEFINE_uint32(virtual_pipeline_parallel, 1, "Number of chunks in PP stage.");
7777

78+
// Activation recompute flags. Train() forwards all four values to
// nn::SetActivationRecomputeConfig to fill in the model config before the model is created.
79+
DEFINE_bool(activation_recompute, false, "Enable activation recompute to trade compute for memory.");
80+
DEFINE_string(recompute_granularity, "full", "Activation recompute granularity: none|full|selective");
81+
DEFINE_string(recompute_method, "none", "Activation recompute method: none|uniform|block");
82+
DEFINE_uint32(recompute_num_layers, 0, "Number of transformer layers per recompute region for uniform/block methods.");
83+
7884
// precision
7985
DEFINE_string(dtype, "float32", "precision used in training (float32/bfloat16)");
8086
// precision check
@@ -186,21 +192,24 @@ void Train(const nn::parallel::Rank &rank) {
186192
nn::TransformerConfig model_config = gpt2::GPT2Config();
187193
std::shared_ptr<nn::Module> model = nullptr;
188194

189-
if (!FLAGS_llmc_filepath.empty()) {
190-
model = gpt2::LoadFromLLMC(FLAGS_llmc_filepath);
191-
} else if (kModelToConfigs.count(FLAGS_model)) {
195+
if (FLAGS_llmc_filepath.empty() && kModelToConfigs.count(FLAGS_model)) {
192196
model_config = kModelToConfigs.at(FLAGS_model);
197+
}
198+
nn::SetActivationRecomputeConfig(&model_config, FLAGS_activation_recompute, FLAGS_recompute_granularity,
199+
FLAGS_recompute_method, static_cast<int64_t>(FLAGS_recompute_num_layers));
200+
201+
if (!FLAGS_llmc_filepath.empty()) {
202+
model = gpt2::LoadFromLLMC(FLAGS_llmc_filepath, model_config);
203+
} else {
193204
model = std::make_shared<nn::TransformerModel>(model_config);
194205
}
195206

207+
CHECK(model) << "GPT2 example expects GPT2 model.";
208+
196209
model->To(device);
197210

198211
utils::PrecisionChecker::BuildNameMap(model.get());
199212

200-
// Get chunk size before wrapping with LoRA (needed for PipelineParallel)
201-
auto gpt2_model = std::dynamic_pointer_cast<nn::TransformerModel>(model);
202-
CHECK(gpt2_model) << "GPT2 example expects GPT2 model.";
203-
204213
// Apply LoRA using GetLoRAModel (in-place injection)
205214
bool lora_enabled = FLAGS_lora_rank > 0;
206215
if (lora_enabled) {

example/llama3/checkpoint_loader.cc

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,11 +35,22 @@ static std::mt19937 gen{kRandomSeed};
3535
namespace {
3636
constexpr int32_t kLLaMA3Magic = 20240803;
3737
constexpr int32_t kLLaMA3FP32Version = 3;
38+
39+
// Copies the activation-recompute settings from `runtime_config` onto `*config`.
// Only recompute_granularity, recompute_method and recompute_num_layers are written;
// every other field of `*config` is left as it was.
//
// config:         destination; dereferenced unconditionally, so it must not be null.
// runtime_config: source of the three recompute fields.
void ApplyRuntimeRecomputeConfig(nn::TransformerConfig *config, const nn::TransformerConfig &runtime_config) {
    nn::TransformerConfig &target = *config;
    target.recompute_granularity = runtime_config.recompute_granularity;
    target.recompute_method = runtime_config.recompute_method;
    target.recompute_num_layers = runtime_config.recompute_num_layers;
}
3844
} // namespace
3945

4046
namespace llama3 {
4147

4248
std::shared_ptr<nn::TransformerModel> LoadFromLLMC(const std::string &filepath) {
49+
return LoadFromLLMC(filepath, llama3::LLaMA3Config());
50+
}
51+
52+
std::shared_ptr<nn::TransformerModel> LoadFromLLMC(const std::string &filepath,
53+
const nn::TransformerConfig &runtime_config) {
4354
if (!std::filesystem::exists(filepath)) {
4455
LOG(FATAL) << "File not found: " << filepath;
4556
}
@@ -80,6 +91,7 @@ std::shared_ptr<nn::TransformerModel> LoadFromLLMC(const std::string &filepath)
8091
llama3_config.use_scaled_rope = static_cast<bool>(use_scaled_rope);
8192
llama3_config.norm_eps = norm_eps;
8293
llama3_config.max_gen_batch_size = max_gen_bs;
94+
ApplyRuntimeRecomputeConfig(&llama3_config, runtime_config);
8395
auto llama3 = std::make_shared<nn::TransformerModel>(llama3_config);
8496

8597
// ========== pp_size:num_stages; vpp_size: num_chunks_per_stage ==========

example/llama3/checkpoint_loader.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,14 @@
33
#include <memory>
44
#include <string>
55

6+
#include "infini_train/include/nn/modules/transformer/transformer_config.h"
7+
68
namespace infini_train::nn {
79
class TransformerModel;
810
} // namespace infini_train::nn
911

1012
namespace llama3 {
1113
// Loads a model from the LLMC checkpoint at `filepath`. The definition forwards to the
// two-argument overload, passing llama3::LLaMA3Config() as `runtime_config`.
std::shared_ptr<infini_train::nn::TransformerModel> LoadFromLLMC(const std::string &filepath);
14+
// Loads a model from the LLMC checkpoint at `filepath`; logs FATAL if the file does not exist.
// The recompute_granularity / recompute_method / recompute_num_layers fields of
// `runtime_config` are copied onto the model config before the TransformerModel is constructed.
std::shared_ptr<infini_train::nn::TransformerModel>
15+
LoadFromLLMC(const std::string &filepath, const infini_train::nn::TransformerConfig &runtime_config);
1216
} // namespace llama3

example/llama3/main.cc

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,12 @@ DEFINE_uint32(tensor_parallel, 1, "Tensor Parallel world size");
7373
DEFINE_bool(sequence_parallel, false, "Whether to enable Sequence Parallel");
7474
DEFINE_uint32(pipeline_parallel, 1, "Pipeline Parallel world size, specified the number of PP stages.");
7575
DEFINE_uint32(virtual_pipeline_parallel, 1, "Number of chunks in PP stage.");
76+
77+
// Activation recompute flags. Train() forwards all four values to
// nn::SetActivationRecomputeConfig to fill in the model config before the model is created.
78+
DEFINE_bool(activation_recompute, false, "Enable activation recompute to trade compute for memory.");
79+
DEFINE_string(recompute_granularity, "full", "Activation recompute granularity: none|full|selective");
80+
DEFINE_string(recompute_method, "none", "Activation recompute method: none|uniform|block");
81+
DEFINE_uint32(recompute_num_layers, 0, "Number of transformer layers per recompute region for uniform/block methods.");
7682
// precision
7783
DEFINE_string(dtype, "float32", "precision used in training (float32/bfloat16)");
7884
// precision check
@@ -171,12 +177,16 @@ void Train(const nn::parallel::Rank &rank) {
171177

172178
nn::TransformerConfig model_config = llama3::LLaMA3Config();
173179
std::shared_ptr<nn::Module> model = nullptr;
180+
nn::SetActivationRecomputeConfig(&model_config, FLAGS_activation_recompute, FLAGS_recompute_granularity,
181+
FLAGS_recompute_method, static_cast<int64_t>(FLAGS_recompute_num_layers));
174182
if (!FLAGS_llmc_filepath.empty()) {
175-
model = llama3::LoadFromLLMC(FLAGS_llmc_filepath);
183+
model = llama3::LoadFromLLMC(FLAGS_llmc_filepath, model_config);
176184
} else {
177185
model = std::make_shared<nn::TransformerModel>(model_config);
178186
}
179187

188+
CHECK(model) << "LLaMA3 example expects LLaMA3 model.";
189+
180190
model->To(device);
181191

182192
utils::PrecisionChecker::BuildNameMap(model.get());
@@ -357,12 +367,20 @@ void Train(const nn::parallel::Rank &rank) {
357367
autocast_guard.Disable();
358368

359369
LOG(INFO) << "Rank " << rank.GlobalRank() << ": finish loss forward";
370+
auto [forward_used_mb, forward_reserved_mb] = impl->GetMemPoolPeakMB(device);
371+
LOG(INFO) << std::format(
372+
"Rank {}: after forward (micro_step {}/{}), peak used: {:5d} MB | peak reserved: {:5d} MB",
373+
rank.GlobalRank(), micro_step + 1, grad_accum_steps, forward_used_mb, forward_reserved_mb);
360374

361375
auto loss_cpu = loss->To(Device());
362376
lossf += static_cast<const float *>(loss_cpu.DataPtr())[0];
363377
LOG(INFO) << "Rank " << rank.GlobalRank() << ": start backward";
364378
loss->Backward();
365379
LOG(INFO) << "Rank " << rank.GlobalRank() << ": finish backward";
380+
auto [backward_used_mb, backward_reserved_mb] = impl->GetMemPoolPeakMB(device);
381+
LOG(INFO) << std::format(
382+
"Rank {}: after backward (micro_step {}/{}), peak used: {:5d} MB | peak reserved: {:5d} MB",
383+
rank.GlobalRank(), micro_step + 1, grad_accum_steps, backward_used_mb, backward_reserved_mb);
366384
}
367385

368386
optimizer->Step();

example/mnist/main.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ int main(int argc, char *argv[]) {
6666
auto new_image = std::make_shared<Tensor>(image->To(device));
6767
auto new_label = std::make_shared<Tensor>(label->To(device));
6868

69-
auto outputs = network.Forward({new_image});
69+
auto outputs = network({new_image});
7070
optimizer.ZeroGrad();
7171

7272
auto loss = loss_fn.Forward({outputs[0], new_label});
@@ -101,7 +101,7 @@ int main(int argc, char *argv[]) {
101101
auto new_label = std::make_shared<Tensor>(label->To(device));
102102

103103
auto label_cpu = label->To(cpu_device);
104-
auto outputs = network.Forward({new_image});
104+
auto outputs = network({new_image});
105105
auto output_cpu = outputs[0]->To(cpu_device);
106106
auto loss = loss_fn.Forward({outputs[0], new_label});
107107
auto loss_cpu = loss[0]->To(cpu_device);

infini_train/include/autocast.h

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,24 @@ struct AutocastContext {
164164
// Global thread-local storage for autocast context
165165
inline thread_local AutocastContext tls_autocast_context;
166166

167+
// Lightweight snapshot for autocast state
168+
struct AutocastState {
169+
bool enabled = false;
170+
Device::DeviceType device_type = Device::DeviceType::kCPU;
171+
DataType autocast_dtype = DataType::kBFLOAT16;
172+
};
173+
174+
// Returns a copy of the calling thread's autocast settings, read from tls_autocast_context
// (enabled, device_type, autocast_dtype).
inline AutocastState GetAutocastState() {
    const auto &ctx = tls_autocast_context;
    AutocastState snapshot;
    snapshot.enabled = ctx.enabled;
    snapshot.device_type = ctx.device_type;
    snapshot.autocast_dtype = ctx.autocast_dtype;
    return snapshot;
}
178+
179+
// Writes the three fields of `state` (enabled, device_type, autocast_dtype) into the calling
// thread's tls_autocast_context. No other member of the context is touched.
inline void SetAutocastState(const AutocastState &state) {
    auto &ctx = tls_autocast_context;
    ctx.enabled = state.enabled;
    ctx.device_type = state.device_type;
    ctx.autocast_dtype = state.autocast_dtype;
}
184+
167185
// RAII guard to enable/disable autocast in a scope
168186
class AutocastGuard {
169187
public:

infini_train/include/autograd/function.h

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
#include <functional>
44
#include <memory>
5+
#include <optional>
56
#include <utility>
67
#include <vector>
78

@@ -21,6 +22,15 @@ class Function : public std::enable_shared_from_this<Function> {
2122
using FunctionPostHook = std::function<void(Function *, const std::vector<std::shared_ptr<Tensor>> &,
2223
const std::vector<std::shared_ptr<Tensor>> &)>;
2324

25+
// Hook types for saved tensors, in alignment with torch.autograd.graph.saved_tensors_hooks.
// Pack hook: receives a tensor and returns an opaque, type-erased state object.
26+
using SavedTensorPackHook = std::function<std::shared_ptr<void>(const std::shared_ptr<Tensor> &)>;
27+
// Unpack hook: receives an opaque state object and returns a tensor.
using SavedTensorUnpackHook = std::function<std::shared_ptr<Tensor>(const std::shared_ptr<void> &)>;
28+
29+
// A pack/unpack hook pair; this is what SavedTensorHooksGuard takes in its constructor.
struct SavedTensorHooks {
30+
SavedTensorPackHook pack;
31+
SavedTensorUnpackHook unpack;
32+
};
33+
2434
static constexpr char kUndefinedType[] = "Undefined";
2535

2636
Function() : type_(kUndefinedType) {}
@@ -34,7 +44,7 @@ class Function : public std::enable_shared_from_this<Function> {
3444
virtual std::vector<std::shared_ptr<Tensor>> Backward(const std::vector<std::shared_ptr<Tensor>> &grad_outputs) = 0;
3545

3646
std::vector<std::shared_ptr<Tensor>> Apply(const std::vector<std::shared_ptr<Tensor>> &input_tensors);
37-
virtual void BackwardPartial(const std::shared_ptr<Tensor> &grad_output, int idx);
47+
virtual void BackwardPartial(std::shared_ptr<Tensor> grad_output, int idx);
3848

3949
void IncreaseDependenciesNumber();
4050

@@ -45,8 +55,33 @@ class Function : public std::enable_shared_from_this<Function> {
4555

4656
const std::string &type() const { return type_; }
4757

58+
// Records `tensors` for later use by Backward. NOTE(review): defined out of line; presumably
// fills saved_tensors_ and applies any active SavedTensorHooks — confirm in the implementation.
void SaveForBackward(const std::vector<std::shared_ptr<Tensor>> &tensors);
59+
// Number of entries currently held in saved_tensors_.
size_t SavedTensorsSize() const { return saved_tensors_.size(); }
60+
// Returns the saved tensor at `index`. NOTE(review): defined out of line; presumably resolves
// hook-packed entries through their unpack hook, and bounds handling is not visible here — confirm.
std::shared_ptr<Tensor> GetSavedTensor(size_t index) const;
61+
// Returns all saved tensors. NOTE(review): defined out of line; ordering presumably matches
// the order passed to SaveForBackward — confirm.
std::vector<std::shared_ptr<Tensor>> GetSavedTensors() const;
62+
63+
// RAII: Register pack/unpack hooks for saved_tensors, align with torch.autograd.graph.saved_tensors_hooks.
// The constructor takes the hook pair by value; the destructor is declared here and defined out of line.
// The guard is non-copyable.
64+
class SavedTensorHooksGuard {
65+
public:
66+
explicit SavedTensorHooksGuard(SavedTensorHooks hooks);
67+
~SavedTensorHooksGuard();
68+
69+
// Copying is disabled.
SavedTensorHooksGuard(const SavedTensorHooksGuard &) = delete;
70+
SavedTensorHooksGuard &operator=(const SavedTensorHooksGuard &) = delete;
71+
72+
private:
73+
// NOTE(review): presumably the hook-stack depth recorded at construction so the destructor
// can pop back to it — confirm against the out-of-line definition.
size_t depth_ = 0;
74+
};
75+
4876
protected:
49-
std::vector<std::shared_ptr<Tensor>> saved_tensors_;
77+
// One saved-for-backward slot: either a tensor, or hook state plus the hook that rebuilds it.
struct SavedTensorEntry {
78+
// Tensor itself, used under default or reentrant version of recomputation
79+
std::shared_ptr<Tensor> tensor;
80+
// Opaque state and the unpack hook used to obtain the target tensor, used under the
// non-reentrant version of recomputation.
// NOTE(review): hook_state is presumably the value returned by the pack hook and later
// passed to `unpack` — confirm in SaveForBackward / GetSavedTensor.
81+
std::shared_ptr<void> hook_state;
82+
SavedTensorUnpackHook unpack;
83+
};
84+
// Entries saved for the backward pass; SavedTensorsSize() reports the size of this vector.
std::vector<SavedTensorEntry> saved_tensors_;
5085
std::vector<bool> needs_input_grad_;
5186

5287
private:

infini_train/include/autograd/grad_mode.h

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,13 @@ class GradMode {
88
// Whether to enable Autograd (enabled by default)
99
static bool IsEnabled() { return grad_enabled_; }
1010
static void SetEnabled(bool enabled) { grad_enabled_ = enabled; }
11+
// Current value of the thread-local propagate_requires_grad_ flag.
static bool PropagateRequiresGrad() { return propagate_requires_grad_; }
12+
// Sets the thread-local propagate_requires_grad_ flag; only the calling thread is affected.
static void SetPropagateRequiresGrad(bool enabled) { propagate_requires_grad_ = enabled; }
1113

1214
private:
1315
// grad mode should be thread_local
1416
static thread_local bool grad_enabled_;
17+
static thread_local bool propagate_requires_grad_;
1518
};
1619

1720
// RAII: Disable grad (align with torch.no_grad)
@@ -34,4 +37,19 @@ class EnableGradGuard {
3437
bool prev_;
3538
};
3639

40+
// RAII: Propagate requires_grad metadata while graph construction is disabled.
41+
// Used by non-reentrant checkpoint recomputation so downstream SetupContext
42+
// calls see the same needs_input_grad_ pattern as the original forward,
43+
// without wiring the recompute graph into the engine.
44+
class PropagateRequiresGradGuard {
45+
public:
46+
PropagateRequiresGradGuard() : prev_(GradMode::PropagateRequiresGrad()) {
47+
GradMode::SetPropagateRequiresGrad(true);
48+
}
49+
~PropagateRequiresGradGuard() { GradMode::SetPropagateRequiresGrad(prev_); }
50+
51+
private:
52+
bool prev_;
53+
};
54+
3755
} // namespace infini_train::autograd

0 commit comments

Comments
 (0)