Skip to content

Commit b2027cb

Browse files
committed
get batch_size and seq_len from encoder outputs instead of inputs
1 parent befe416 commit b2027cb

1 file changed

Lines changed: 1 addition & 1 deletion

File tree

src/cnlpt/modeling/models/projection_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -293,7 +293,7 @@ def forward(
293293

294294
outputs = self.encoder(input_ids, **kwargs)
295295

296-
batch_size, seq_len = input_ids.shape
296+
batch_size, seq_len, _ = outputs.last_hidden_state.shape
297297

298298
logits = []
299299

0 commit comments

Comments
 (0)