@@ -177,17 +177,27 @@ void AdamWStateManager::allocate_state(IModel& model, cudaStream_t stream, EAllo
177177 }
178178
179179 mBlocksMScales .resize (mConfig .NumLayers );
180+
180181 if (mMType == ETensorDType::FP8_E4M3 ) {
182+ auto prepare_shape_for_scales = [&](auto && c) {
183+ // creates shards same as main weight
184+ auto sharded = shard_empty_container (flattened_view (c), mWorld );
185+ // flatten the local shard
186+ auto flattened = flattened_view (sharded);
187+ // and group into scaling groups
188+ auto grouped = shard_empty_container (std::move (flattened), 128 );
189+ return grouped;
190+ };
181191 // we "shard" for 128 as many GPUs, so that we get 1 scale per 128 weights.
182192 for (int i = 0 ; i < mConfig .NumLayers ; ++i) {
183- mBlocksMScales [i] = shard_empty_container (model.create_block_container (mConfig , ETensorDType::FP32 , ETensorDType::FP32 ), 128 * mWorld );
193+ mBlocksMScales [i] = prepare_shape_for_scales (model.create_block_container (mConfig , ETensorDType::FP32 , ETensorDType::FP32 ));
184194 alloc_lazy.allocate (mBlocksMScales [i]);
185195 alloc_lazy.commit (alloc, EAllocationType::ON_DEVICE , " m_block_scales" );
186196 visit ([stream](Tensor& t){
187197 fill_constant (t, 1 .f , t.nelem (), stream);
188198 }, mBlocksMScales [i]);
189199 }
190- mNonBlockMScales = shard_empty_container (model.create_non_block_container (mConfig , ETensorDType::FP32 , ETensorDType::FP32 ), 128 * mWorld );
200+ mNonBlockMScales = prepare_shape_for_scales (model.create_non_block_container (mConfig , ETensorDType::FP32 , ETensorDType::FP32 ));
191201 alloc_lazy.allocate (mNonBlockMScales );
192202 alloc_lazy.commit (alloc, EAllocationType::ON_DEVICE , " m_nonblock_scales" );
193203 visit ([stream](Tensor& t){
0 commit comments