Commit f8cf40c

Add extra inference call for issue #173
1 parent 838d14d commit f8cf40c

1 file changed

Lines changed: 12 additions & 0 deletions

aiu_fms_testing_utils/scripts/drive_paged_programs.py

@@ -426,6 +426,18 @@ def __load_validation_info(
     **extra_kwargs,
 )
 
+# do an extra inference call to workaround the issue on z/OS where the first inference
+# result is always incorrect during multi-AIU (issue 173)
+extract_validation_information(
+    model,
+    input_ids,
+    max_new_tokens,
+    post_iteration_hook=None,
+    last_n_tokens=64,
+    prefill_chunk_size=args.prefill_chunk_size,
+    **extra_kwargs,
+)
+
 if USE_DISTRIBUTED:
     # wait for rank0 to be finished as it is the only one generating the criteria json
     # this is needed since otherwise we may run into a race condition
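
The added call follows a "throwaway first call" pattern: run the inference once, ignore its result, and rely only on later calls. A minimal sketch of that pattern in isolation is below. The names `call_discarding_first` and `FlakyFirstInference` are hypothetical and do not exist in `aiu_fms_testing_utils`; the stand-in class only simulates a backend whose first result is wrong, it does not reproduce the z/OS multi-AIU behavior from issue 173.

```python
from typing import Any, Callable


def call_discarding_first(fn: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
    """Call fn twice with the same arguments and return only the second result.

    Hypothetical helper, shown only to illustrate the workaround pattern.
    """
    fn(*args, **kwargs)  # throwaway call; its result is deliberately ignored
    return fn(*args, **kwargs)


class FlakyFirstInference:
    """Hypothetical stand-in for a backend whose first result is incorrect."""

    def __init__(self) -> None:
        self.calls = 0

    def __call__(self, input_ids: list[int], max_new_tokens: int) -> list[int]:
        self.calls += 1
        if self.calls == 1:
            # Simulated bad first result.
            return [-1] * max_new_tokens
        # Simulated correct result on every later call.
        return [tok + 1 for tok in input_ids][:max_new_tokens]


infer = FlakyFirstInference()
result = call_discarding_first(infer, [1, 2, 3], max_new_tokens=2)
print(result, infer.calls)
```

The trade-off is one extra full inference pass per run, which is why the commit scopes the call with a comment that ties it to the specific issue.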
