|
34 | 34 | from torchrec.inference.modules import quantize_embeddings |
35 | 35 | from torchrec.modules.embedding_configs import BaseEmbeddingConfig |
36 | 36 | from torchrec.modules.embedding_modules import ( |
| 37 | + EmbeddingBagCollection, |
37 | 38 | EmbeddingBagCollectionInterface, |
38 | 39 | EmbeddingCollection, |
39 | 40 | EmbeddingCollectionInterface, |
40 | 41 | ) |
41 | 42 | from torchrec.quant.embedding_modules import ( |
42 | 43 | EmbeddingCollection as QuantEmbeddingCollection, |
43 | 44 | ) |
| 45 | +from torchrec.quant.embedding_modules import ( |
| 46 | + quant_prep_enable_cache_features_order, |
| 47 | +) |
44 | 48 | from torchrec.sparse import jagged_tensor |
45 | 49 |
|
46 | 50 | from tzrec.acc import utils as acc_utils |
@@ -213,9 +217,16 @@ def export_model_normal( |
213 | 217 | logger.info("quantize embeddings...") |
214 | 218 | additional_qconfig_spec_keys = [] |
215 | 219 | additional_mapping = {} |
| 220 | + cache_order_types = [EmbeddingBagCollection] |
216 | 221 | if acc_utils.is_ec_quant(): |
217 | 222 | additional_qconfig_spec_keys.append(EmbeddingCollection) |
218 | 223 | additional_mapping[EmbeddingCollection] = QuantEmbeddingCollection |
| 224 | + cache_order_types.append(EmbeddingCollection) |
| 225 | + # Cache the feature-permute order as an on-device buffer instead of |
| 226 | + # rebuilding `torch.tensor(order, device=cuda)` (a blocking H2D copy) |
| 227 | + # on every forward. Must run before quantize_embeddings so the quant |
| 228 | + # modules pick it up via `from_float`. |
| 229 | + quant_prep_enable_cache_features_order(model, cache_order_types) |
219 | 230 | quantize_embeddings( |
220 | 231 | model, |
221 | 232 | dtype=acc_utils.quant_dtype(), |
|
0 commit comments