We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 6c09c64 commit 868a124Copy full SHA for 868a124
1 file changed
src/lmflow/models/hf_decoder_model.py
@@ -248,7 +248,9 @@ def __init__(
248
# We resize the embeddings only when necessary to avoid index errors.
249
# If you are creating a model from scratch on a small vocab and want a
250
# smaller embedding size, remove this test.
251
- embedding_size = model.get_input_embeddings().weight.shape[0]
+ with deepspeed.zero.GatheredParameters(model.get_input_embeddings().weight, modifier_rank=None):
252
+ weights = model.get_input_embeddings().weight
253
+ embedding_size = weights.shape[0]
254
if len(tokenizer) > embedding_size:
255
model.resize_token_embeddings(len(tokenizer))
256
0 commit comments