Skip to content

Commit ef34fad

Browse files
committed
make FP8-M work also if first dim is not divisible by 512
1 parent 5f3bb5f commit ef34fad

1 file changed

Lines changed: 12 additions & 2 deletions

File tree

src/training/adamw_optimizer.cpp

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -177,17 +177,27 @@ void AdamWStateManager::allocate_state(IModel& model, cudaStream_t stream, EAllo
177177
}
178178

179179
mBlocksMScales.resize(mConfig.NumLayers);
180+
180181
if(mMType == ETensorDType::FP8_E4M3) {
182+
auto prepare_shape_for_scales = [&](auto&& c) {
183+
// creates shards same as main weight
184+
auto sharded = shard_empty_container(flattened_view(c), mWorld);
185+
// flatten the local shard
186+
auto flattened = flattened_view(sharded);
187+
// and group into scaling groups
188+
auto grouped = shard_empty_container(std::move(flattened), 128);
189+
return grouped;
190+
};
181191
// we "shard" for 128 as many GPUs, so that we get 1 scale per 128 weights.
182192
for (int i = 0; i < mConfig.NumLayers; ++i) {
183-
mBlocksMScales[i] = shard_empty_container(model.create_block_container(mConfig, ETensorDType::FP32, ETensorDType::FP32), 128 * mWorld);
193+
mBlocksMScales[i] = prepare_shape_for_scales(model.create_block_container(mConfig, ETensorDType::FP32, ETensorDType::FP32));
184194
alloc_lazy.allocate(mBlocksMScales[i]);
185195
alloc_lazy.commit(alloc, EAllocationType::ON_DEVICE, "m_block_scales");
186196
visit([stream](Tensor& t){
187197
fill_constant(t, 1.f, t.nelem(), stream);
188198
}, mBlocksMScales[i]);
189199
}
190-
mNonBlockMScales = shard_empty_container(model.create_non_block_container(mConfig, ETensorDType::FP32, ETensorDType::FP32), 128 * mWorld);
200+
mNonBlockMScales = prepare_shape_for_scales(model.create_non_block_container(mConfig, ETensorDType::FP32, ETensorDType::FP32));
191201
alloc_lazy.allocate(mNonBlockMScales);
192202
alloc_lazy.commit(alloc, EAllocationType::ON_DEVICE, "m_nonblock_scales");
193203
visit([stream](Tensor& t){

0 commit comments

Comments
 (0)