@@ -73,6 +73,12 @@ DEFINE_uint32(tensor_parallel, 1, "Tensor Parallel world size");
7373DEFINE_bool (sequence_parallel, false , " Whether to enable Sequence Parallel" );
7474DEFINE_uint32 (pipeline_parallel, 1 , " Pipeline Parallel world size, specified the number of PP stages." );
7575DEFINE_uint32 (virtual_pipeline_parallel, 1 , " Number of chunks in PP stage." );
76+
77+ // activation recompute: command-line flags for trading extra compute for lower memory use; all four are forwarded to nn::SetActivationRecomputeConfig() in Train().
78+ DEFINE_bool (activation_recompute, false , " Enable activation recompute to trade compute for memory." ); // master switch; off by default
79+ DEFINE_string (recompute_granularity, " full" , " Activation recompute granularity: none|full|selective" ); // NOTE(review): presumably only consulted when activation_recompute is true -- confirm in SetActivationRecomputeConfig
80+ DEFINE_string (recompute_method, " none" , " Activation recompute method: none|uniform|block" ); // accepted values are listed in the help text
81+ DEFINE_uint32 (recompute_num_layers, 0 , " Number of transformer layers per recompute region for uniform/block methods." ); // unsigned flag; cast to int64_t at the call site in Train()
7682// precision
7783DEFINE_string (dtype, " float32" , " precision used in training (float32/bfloat16)" );
7884// precision check
@@ -171,12 +177,16 @@ void Train(const nn::parallel::Rank &rank) {
171177
172178 nn::TransformerConfig model_config = llama3::LLaMA3Config ();
173179 std::shared_ptr<nn::Module> model = nullptr ;
180+ nn::SetActivationRecomputeConfig (&model_config, FLAGS_activation_recompute, FLAGS_recompute_granularity,
181+ FLAGS_recompute_method, static_cast <int64_t >(FLAGS_recompute_num_layers));
174182 if (!FLAGS_llmc_filepath.empty ()) {
175- model = llama3::LoadFromLLMC (FLAGS_llmc_filepath);
183+ model = llama3::LoadFromLLMC (FLAGS_llmc_filepath, model_config );
176184 } else {
177185 model = std::make_shared<nn::TransformerModel>(model_config);
178186 }
179187
188+ CHECK (model) << " LLaMA3 example expects LLaMA3 model." ;
189+
180190 model->To (device);
181191
182192 utils::PrecisionChecker::BuildNameMap (model.get ());
@@ -357,12 +367,20 @@ void Train(const nn::parallel::Rank &rank) {
357367 autocast_guard.Disable ();
358368
359369 LOG (INFO ) << " Rank " << rank.GlobalRank () << " : finish loss forward" ;
370+ auto [forward_used_mb, forward_reserved_mb] = impl->GetMemPoolPeakMB (device);
371+ LOG (INFO ) << std::format (
372+ " Rank {}: after forward (micro_step {}/{}), peak used: {:5d} MB | peak reserved: {:5d} MB" ,
373+ rank.GlobalRank (), micro_step + 1 , grad_accum_steps, forward_used_mb, forward_reserved_mb);
360374
361375 auto loss_cpu = loss->To (Device ());
362376 lossf += static_cast <const float *>(loss_cpu.DataPtr ())[0 ];
363377 LOG (INFO ) << " Rank " << rank.GlobalRank () << " : start backward" ;
364378 loss->Backward ();
365379 LOG (INFO ) << " Rank " << rank.GlobalRank () << " : finish backward" ;
380+ auto [backward_used_mb, backward_reserved_mb] = impl->GetMemPoolPeakMB (device);
381+ LOG (INFO ) << std::format (
382+ " Rank {}: after backward (micro_step {}/{}), peak used: {:5d} MB | peak reserved: {:5d} MB" ,
383+ rank.GlobalRank (), micro_step + 1 , grad_accum_steps, backward_used_mb, backward_reserved_mb);
366384 }
367385
368386 optimizer->Step ();
0 commit comments