Skip to content

Commit 6c8e2bd

Browse files
[bugfix] export: cache feature-permute order to avoid per-forward H2D sync (#527)
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent aff7c2b commit 6c8e2bd

2 files changed

Lines changed: 12 additions & 1 deletion

File tree

tzrec/utils/export_util.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,13 +34,17 @@
3434
from torchrec.inference.modules import quantize_embeddings
3535
from torchrec.modules.embedding_configs import BaseEmbeddingConfig
3636
from torchrec.modules.embedding_modules import (
37+
EmbeddingBagCollection,
3738
EmbeddingBagCollectionInterface,
3839
EmbeddingCollection,
3940
EmbeddingCollectionInterface,
4041
)
4142
from torchrec.quant.embedding_modules import (
4243
EmbeddingCollection as QuantEmbeddingCollection,
4344
)
45+
from torchrec.quant.embedding_modules import (
46+
quant_prep_enable_cache_features_order,
47+
)
4448
from torchrec.sparse import jagged_tensor
4549

4650
from tzrec.acc import utils as acc_utils
@@ -213,9 +217,16 @@ def export_model_normal(
213217
logger.info("quantize embeddings...")
214218
additional_qconfig_spec_keys = []
215219
additional_mapping = {}
220+
cache_order_types = [EmbeddingBagCollection]
216221
if acc_utils.is_ec_quant():
217222
additional_qconfig_spec_keys.append(EmbeddingCollection)
218223
additional_mapping[EmbeddingCollection] = QuantEmbeddingCollection
224+
cache_order_types.append(EmbeddingCollection)
225+
# Cache the feature-permute order as an on-device buffer instead of
226+
# rebuilding `torch.tensor(order, device=cuda)` (a blocking H2D copy)
227+
# on every forward. Must run before quantize_embeddings so the quant
228+
# modules pick it up via `from_float`.
229+
quant_prep_enable_cache_features_order(model, cache_order_types)
219230
quantize_embeddings(
220231
model,
221232
dtype=acc_utils.quant_dtype(),

tzrec/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,4 @@
99
# See the License for the specific language governing permissions and
1010
# limitations under the License.
1111

12-
__version__ = "1.2.13"
12+
__version__ = "1.2.14"

0 commit comments

Comments
 (0)