Commit 868a124

fix bug when using deepspeedzero3 getting embedding size
1 parent 6c09c64 commit 868a124

1 file changed

Lines changed: 3 additions & 1 deletion

src/lmflow/models/hf_decoder_model.py

@@ -248,7 +248,9 @@ def __init__(
 # We resize the embeddings only when necessary to avoid index errors.
 # If you are creating a model from scratch on a small vocab and want a
 # smaller embedding size, remove this test.
-embedding_size = model.get_input_embeddings().weight.shape[0]
+with deepspeed.zero.GatheredParameters(model.get_input_embeddings().weight, modifier_rank=None):
+    weights = model.get_input_embeddings().weight
+    embedding_size = weights.shape[0]
 if len(tokenizer) > embedding_size:
     model.resize_token_embeddings(len(tokenizer))
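
The change above reads the embedding shape inside `deepspeed.zero.GatheredParameters`. Under ZeRO stage 3 each rank holds only a partition of every parameter, so a shape read outside the context may not reflect the full embedding matrix. The following is a minimal sketch of the same pattern, not the repository's code: the helper names `gathered` and `embedding_rows` and the stand-in model are hypothetical, and the fallback to a no-op context when DeepSpeed is not installed is an assumption added so the snippet runs anywhere.

```python
from contextlib import nullcontext
from types import SimpleNamespace


def gathered(param):
    """Return a context in which `param` is available in full (hypothetical helper)."""
    try:
        import deepspeed  # optional; only relevant when training with ZeRO-3
    except ImportError:
        # Assumption: without DeepSpeed nothing is partitioned, so a no-op suffices.
        return nullcontext()
    # modifier_rank=None means the parameter is only read, not modified,
    # so no rank's changes are broadcast when the context exits.
    return deepspeed.zero.GatheredParameters(param, modifier_rank=None)


def embedding_rows(model):
    """Number of rows in the input embedding matrix (hypothetical helper)."""
    weight = model.get_input_embeddings().weight
    with gathered(weight):
        # Read the shape while the parameter is gathered.
        return weight.shape[0]


# Stand-in for a real model, used only to exercise the helper.
fake_weight = SimpleNamespace(shape=(32000, 4096))
fake_model = SimpleNamespace(
    get_input_embeddings=lambda: SimpleNamespace(weight=fake_weight)
)
print(embedding_rows(fake_model))  # prints 32000
```

With a real model, the result would feed the existing `len(tokenizer) > embedding_size` check unchanged. For parameters that are not ZeRO-partitioned, `GatheredParameters` is documented to act as a no-op, which is why the same code path can serve both partitioned and unpartitioned runs.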
