diff --git a/aiu_fms_testing_utils/utils/paged.py b/aiu_fms_testing_utils/utils/paged.py index 49807dc1..377280af 100644 --- a/aiu_fms_testing_utils/utils/paged.py +++ b/aiu_fms_testing_utils/utils/paged.py @@ -4,6 +4,7 @@ import time from typing import Any, Callable, List, MutableMapping, Optional, Tuple, Union import torch +import torch.nn.functional as F import fms.utils.spyre.paged # noqa from aiu_fms_testing_utils.utils import get_pad_size