diff --git a/docs/source/models/hstu_match.md b/docs/source/models/hstu_match.md index 3f013d2a..16dc46f9 100644 --- a/docs/source/models/hstu_match.md +++ b/docs/source/models/hstu_match.md @@ -256,13 +256,16 @@ model_config { ## 模型导出 -HSTU Match 模型导出时需要设置环境变量 `ENABLE_AOT=1` 启用 AOT Inductor 导出。例如: +HSTU Match 模型导出时,若使用 Triton kernel,需要设置环境变量 `ENABLE_AOT=1` 启用 AOT Inductor 导出。 + +同时需要通过命令行参数 `--item_input_path` 指定 item 侧的输入数据路径(一行一个 item 的 parquet,schema 与候选序列子特征对齐,例如包含 `video_id` 列)。item tower 导出时会从该路径读取一个样本 batch 用于 trace;user tower 不受影响,仍使用 `train_input_path`。例如: ``` ENABLE_AOT=1 torchrun --master_addr=localhost --master_port=32555 \ --nnodes=1 --nproc-per-node=1 --node_rank=0 \ -m tzrec.export \ --pipeline_config_path experiments/hstu_match/pipeline.config \ + --item_input_path experiments/hstu_match/item_data/\*.parquet \ --export_dir experiments/hstu_match/export ``` diff --git a/scripts/ci/ci_data.sh b/scripts/ci/ci_data.sh index 0d2c1d67..bb9bff66 100644 --- a/scripts/ci/ci_data.sh +++ b/scripts/ci/ci_data.sh @@ -13,3 +13,4 @@ wget https://tzrec.oss-accelerate.aliyuncs.com/data/test/kuairand-mot-1k-eval-c4 wget https://tzrec.oss-accelerate.aliyuncs.com/data/test/kuairand-1k-match-train-c4096-s100-f1892eabc70ae3407afe9ff5bca8cb5f.parquet -O data/test/kuairand-1k-match-train-c4096-s100.parquet wget https://tzrec.oss-accelerate.aliyuncs.com/data/test/kuairand-1k-match-eval-c4096-s100-e4ca5e15d157efa723041cd05c127228.parquet -O data/test/kuairand-1k-match-eval-c4096-s100.parquet wget https://tzrec.oss-accelerate.aliyuncs.com/data/test/kuairand-1k-match-item-gl-3d459148303acd9f838da108efcc40e5.txt -O data/test/kuairand-1k-match-item-gl.txt +wget https://tzrec.oss-accelerate.aliyuncs.com/data/test/kuairand-1k-match-item-c1-8dcadabdc3e9049ed9c2250565b4b134.parquet -O data/test/kuairand-1k-match-item-c1.parquet diff --git a/tzrec/export.py b/tzrec/export.py index 3e4eba45..2286ae61 100644 --- a/tzrec/export.py +++ b/tzrec/export.py @@ -48,6 +48,15 @@ help="JSON 
string of extra key/value pairs merged into model_acc.json, " 'e.g. \'{"cand_seq_pk": "cand_seq"}\' for DlrmHSTU.', ) + parser.add_argument( + "--item_input_path", + type=str, + default=None, + help="Optional input path for the item-tower's predict-mode " + "dataloader. When set, the item tower reads from this path " + "(a one-row-per-item table matching the scalar export view) " + "instead of `train_input_path`.", + ) args, extra_args = parser.parse_known_args() additional_export_config = ( @@ -62,4 +71,5 @@ checkpoint_path=args.checkpoint_path, asset_files=args.asset_files, additional_export_config=additional_export_config, + item_input_path=args.item_input_path, ) diff --git a/tzrec/features/feature.py b/tzrec/features/feature.py index bd5bb416..d7b90c23 100644 --- a/tzrec/features/feature.py +++ b/tzrec/features/feature.py @@ -1230,6 +1230,43 @@ def create_features( return features +def project_grouped_sequence_feature_to_scalar( + feature: BaseFeature, +) -> feature_pb2.FeatureConfig: + """Return a scalar export FeatureConfig for a grouped sequence sub-feature. + + Rewraps the inner ``SeqFeatureConfig`` as a top-level ``FeatureConfig`` + and materializes the source's effective ``default_value`` / ``value_dim`` + so the exported scalar feature matches the training sub-feature + (otherwise scalar mode defaults differ from sequence mode). + + Args: + feature: a grouped sequence sub-feature. + + Returns: + a fresh FeatureConfig for ``create_features()`` to build as scalar. 
+ """ + if not feature.is_grouped_sequence: + raise ValueError( + "project_grouped_sequence_feature_to_scalar only accepts grouped " + f"sequence sub-features; got {feature.name} " + "(is_grouped_sequence=False)" + ) + src_cfg = feature.feature_config # SeqFeatureConfig + feat_type = src_cfg.WhichOneof("feature") + src_msg = getattr(src_cfg, feat_type) + + scalar_cfg = feature_pb2.FeatureConfig() + dst_msg = getattr(scalar_cfg, feat_type) + dst_msg.CopyFrom(src_msg) + + if hasattr(dst_msg, "default_value") and not dst_msg.default_value: + dst_msg.default_value = feature.default_value + if hasattr(dst_msg, "value_dim") and not dst_msg.HasField("value_dim"): + dst_msg.value_dim = feature.value_dim + return scalar_cfg + + def _copy_assets( feature: BaseFeature, asset_dir: Optional[str] = None, diff --git a/tzrec/features/feature_test.py b/tzrec/features/feature_test.py index 72267897..c72409c7 100644 --- a/tzrec/features/feature_test.py +++ b/tzrec/features/feature_test.py @@ -750,5 +750,78 @@ def test_sequence_input_names(self, fg_mode): ) +class ProjectGroupedSequenceFeatureToScalarTest(unittest.TestCase): + def _build_grouped(self, seq_sub_cfg): + feature_cfgs = [ + feature_pb2.FeatureConfig( + sequence_feature=feature_pb2.SequenceFeature( + sequence_name="cand_seq", + sequence_delim="|", + sequence_length=100, + features=[seq_sub_cfg], + ) + ), + ] + return feature_lib.create_features(feature_cfgs) + + def test_projection_materializes_defaults_and_passes_through_create_features(self): + # id_feature: default_value / value_dim materialization + create_features. + id_sub_cfg = feature_pb2.SeqFeatureConfig( + id_feature=feature_pb2.IdFeature( + feature_name="video_id", + expression="item:video_id", + embedding_dim=32, + num_buckets=10000000, + ) + ) + features = self._build_grouped(id_sub_cfg) + self.assertEqual(len(features), 1) + sub_feature = features[0] + self.assertTrue(sub_feature.is_grouped_sequence) + # Sequence-effective defaults on the source. 
+ self.assertEqual(sub_feature.default_value, "0") + self.assertEqual(sub_feature.value_dim, 1) + + scalar_cfg = feature_lib.project_grouped_sequence_feature_to_scalar(sub_feature) + self.assertEqual(scalar_cfg.WhichOneof("feature"), "id_feature") + # Materialized onto the scalar proto. + self.assertEqual(scalar_cfg.id_feature.default_value, "0") + self.assertTrue(scalar_cfg.id_feature.HasField("value_dim")) + self.assertEqual(scalar_cfg.id_feature.value_dim, 1) + # Source proto not mutated. + self.assertEqual(sub_feature.feature_config.id_feature.default_value, "") + self.assertFalse(sub_feature.feature_config.id_feature.HasField("value_dim")) + + # create_features rebuilds it as a top-level scalar feature. + scalar_features = feature_lib.create_features([scalar_cfg]) + self.assertEqual(len(scalar_features), 1) + scalar = scalar_features[0] + self.assertEqual(scalar.name, "video_id") + self.assertFalse(scalar.is_grouped_sequence) + self.assertEqual(scalar.value_dim, 1) + + # raw_feature: confirms the helper isn't hard-coded to id_feature. 
+ raw_sub_cfg = feature_pb2.SeqFeatureConfig( + raw_feature=feature_pb2.RawFeature( + feature_name="watch_time", expression="user:watch_time" + ) + ) + raw_features = self._build_grouped(raw_sub_cfg) + raw_scalar_cfg = feature_lib.project_grouped_sequence_feature_to_scalar( + raw_features[0] + ) + self.assertEqual(raw_scalar_cfg.WhichOneof("feature"), "raw_feature") + + def test_projection_rejects_non_grouped_feature(self): + feature_cfgs = [ + feature_pb2.FeatureConfig( + id_feature=feature_pb2.IdFeature(feature_name="user_id") + ), + ] + features = feature_lib.create_features(feature_cfgs) + with self.assertRaisesRegex(ValueError, "is_grouped_sequence=False"): + feature_lib.project_grouped_sequence_feature_to_scalar(features[0]) + + if __name__ == "__main__": unittest.main() diff --git a/tzrec/main.py b/tzrec/main.py index 45e62f04..f8fdaec6 100644 --- a/tzrec/main.py +++ b/tzrec/main.py @@ -895,6 +895,7 @@ def export( checkpoint_path: Optional[str] = None, asset_files: Optional[str] = None, additional_export_config: Optional[Dict[str, str]] = None, + item_input_path: Optional[str] = None, ) -> None: """Export a EasyRec model. @@ -906,6 +907,9 @@ def export( asset_files (str, optional): more files will be copied to export_dir. additional_export_config (dict, optional): extra key/value pairs merged into model_acc.json (e.g. ``{"cand_seq_pk": "cand_seq"}`` for DlrmHSTU). + item_input_path (str, optional): override for the item tower's + predict-mode dataloader input path. When set, the item tower + reads from this path instead of ``train_input_path``. """ is_rank_zero = int(os.environ.get("RANK", 0)) == 0 @@ -936,6 +940,10 @@ def export( sampler_type=None, ) InferWrapper = ScriptWrapper + # Flip to inference *before* wrapping so view-dependent state + # (e.g. HSTUMatchItemTower's lazy properties, wrapper EmbeddingGroups) + # is snapshot from the scalar view. 
+ model.set_is_inference(True) model = InferWrapper(model) if not checkpoint_path: @@ -959,6 +967,10 @@ def export( ) tower = InferWrapper(wrapper(module, name)) tower_export_dir = os.path.join(export_dir, name.replace("_tower", "")) + # item-tower-only; user tower falls back to `train_input_path`. + tower_data_input_path = ( + item_input_path if name == "item_tower" else None + ) export_model( ori_pipeline_config, tower, @@ -966,6 +978,7 @@ def export( tower_export_dir, assets=assets, additional_export_config=additional_export_config, + data_input_path=tower_data_input_path, ) elif isinstance(model.model, TDM): for name, module in model.model.named_children(): diff --git a/tzrec/models/hstu.py b/tzrec/models/hstu.py index ae4be071..2e3c6acd 100644 --- a/tzrec/models/hstu.py +++ b/tzrec/models/hstu.py @@ -15,8 +15,17 @@ import torch.nn.functional as F from torch import nn -from tzrec.datasets.utils import CAND_POS_LENGTHS, HARD_NEG_INDICES, Batch -from tzrec.features.feature import BaseFeature +from tzrec.datasets.utils import ( + BASE_DATA_GROUP, + CAND_POS_LENGTHS, + HARD_NEG_INDICES, + Batch, +) +from tzrec.features.feature import ( + BaseFeature, + create_features, + project_grouped_sequence_feature_to_scalar, +) from tzrec.models.match_model import MatchModel, MatchTowerWoEG from tzrec.modules.embedding import EmbeddingGroup from tzrec.modules.gr.hstu_transducer import HSTUMatchEncoder @@ -151,8 +160,18 @@ def __init__( # tower_config.input names on the user-tower proto). Use the item-side # tower_config.input here, which equals feature_groups[0].group_name. self._group_name = tower_config.input + # MLP sized off the training candidate group; the scalar view has + # identical per-feature embedding dim. candidate_dims = embedding_group.group_dims(f"{self._group_name}.sequence") candidate_total_dim = sum(candidate_dims) + + # Lazy caches for the scalar export view (populated on first + # property access after `set_is_inference(True)`). 
+ self._features_scalar: Optional[List[BaseFeature]] = None + self._feature_groups_scalar: Optional[List[model_pb2.FeatureGroupConfig]] = None + # `MatchTowerWoEG` derives from `nn.Module`, not `BaseModule`, + # so init `_is_inference` here. + self._is_inference: bool = False if tower_config.HasField("mlp"): self.mlp: torch.nn.Module = MLP( in_features=candidate_total_dim, @@ -166,6 +185,48 @@ def __init__( if self._output_dim > 0: self.output = nn.Linear(mlp_out_dim, output_dim) + @property + def features(self) -> List[BaseFeature]: + """Item features (training: grouped sub-features; export: scalar projection).""" + if self._is_inference: + if self._features_scalar is None: + self._build_scalar_features() + return self._features_scalar + return self._features + + @property + def feature_groups(self) -> List[model_pb2.FeatureGroupConfig]: + """Item feature_groups in the current view (see ``features``).""" + if self._is_inference: + if self._feature_groups_scalar is None: + self._build_scalar_features() + return self._feature_groups_scalar + return self._feature_groups + + def _build_scalar_features(self) -> None: + """Project each grouped sequence sub-feature into a scalar export feature.""" + scalar_configs = [ + project_grouped_sequence_feature_to_scalar(f) for f in self._features + ] + source = self._features[0] + scalar_features = create_features( + scalar_configs, + fg_mode=source.fg_mode, + neg_fields=None, + fg_encoded_multival_sep=source._fg_encoded_multival_sep, + force_base_data_group=any( + f.data_group == BASE_DATA_GROUP for f in self._features + ), + ) + self._features_scalar = scalar_features + self._feature_groups_scalar = [ + model_pb2.FeatureGroupConfig( + group_name=self._group_name, + feature_names=[f.name for f in scalar_features], + group_type=model_pb2.JAGGED_SEQUENCE, + ) + ] + def forward(self, grouped_features: Dict[str, torch.Tensor]) -> torch.Tensor: """Forward the item tower. 
@@ -175,7 +236,9 @@ def forward(self, grouped_features: Dict[str, torch.Tensor]) -> torch.Tensor: Returns: item embeddings of shape (sum_candidates, D). """ - cand_emb = grouped_features[f"{self._group_name}.sequence"] + # `.sequence` (jagged) at training, `.query` (scalar) at export. + suffix = ".query" if self._is_inference else ".sequence" + cand_emb = grouped_features[self._group_name + suffix] item_emb = self.mlp(cand_emb) if self._output_dim > 0: item_emb = self.output(item_emb) diff --git a/tzrec/models/hstu_test.py b/tzrec/models/hstu_test.py index 1f2fa6a9..4f6c6f57 100644 --- a/tzrec/models/hstu_test.py +++ b/tzrec/models/hstu_test.py @@ -36,29 +36,51 @@ def _build_model(device: torch.device) -> HSTUMatch: - """Build an HSTUMatch model with standard test configuration.""" + """Build an HSTUMatch model with standard test configuration. + + Mirrors the production grouped-sequence pattern: `uih_seq` and + `cand_seq` each carry a `video_id` sub-feature with aligned bucket / + dim / `embedding_name` so the two flattened features share one + embedding table. `uih_seq` also carries the `historical_ts` raw + sub-feature for the timestamp dense path. 
+ """ feature_cfgs = [ feature_pb2.FeatureConfig( - sequence_id_feature=feature_pb2.IdFeature( - feature_name="historical_ids", + sequence_feature=feature_pb2.SequenceFeature( + sequence_name="uih_seq", sequence_length=210, - embedding_dim=64, - num_buckets=3953, + features=[ + feature_pb2.SeqFeatureConfig( + id_feature=feature_pb2.IdFeature( + feature_name="video_id", + embedding_dim=64, + num_buckets=1000, + embedding_name="video_id_emb", + ) + ), + feature_pb2.SeqFeatureConfig( + raw_feature=feature_pb2.RawFeature( + feature_name="historical_ts", + ) + ), + ], ) ), feature_pb2.FeatureConfig( - sequence_id_feature=feature_pb2.IdFeature( - feature_name="item_id", + sequence_feature=feature_pb2.SequenceFeature( + sequence_name="cand_seq", sequence_length=10, sequence_delim=";", - embedding_dim=64, - num_buckets=1000, - ) - ), - feature_pb2.FeatureConfig( - sequence_raw_feature=feature_pb2.RawFeature( - feature_name="historical_ts", - sequence_length=210, + features=[ + feature_pb2.SeqFeatureConfig( + id_feature=feature_pb2.IdFeature( + feature_name="video_id", + embedding_dim=64, + num_buckets=1000, + embedding_name="video_id_emb", + ) + ), + ], ) ), ] @@ -66,17 +88,17 @@ def _build_model(device: torch.device) -> HSTUMatch: feature_groups = [ model_pb2.FeatureGroupConfig( group_name="uih", - feature_names=["historical_ids"], + feature_names=["uih_seq__video_id"], group_type=model_pb2.FeatureGroupType.JAGGED_SEQUENCE, ), model_pb2.FeatureGroupConfig( group_name="candidate", - feature_names=["item_id"], + feature_names=["cand_seq__video_id"], group_type=model_pb2.FeatureGroupType.JAGGED_SEQUENCE, ), model_pb2.FeatureGroupConfig( group_name="uih_timestamp", - feature_names=["historical_ts"], + feature_names=["uih_seq__historical_ts"], group_type=model_pb2.FeatureGroupType.JAGGED_SEQUENCE, ), ] @@ -140,12 +162,12 @@ def _build_batch(device: torch.device) -> Batch: pos_lengths = [1, 1]. 
""" sparse_feature = KeyedJaggedTensor.from_lengths_sync( - keys=["historical_ids", "item_id"], + keys=["uih_seq__video_id", "cand_seq__video_id"], values=torch.tensor([1, 2, 3, 4, 5, 6, 7, 100, 200, 101, 201]), lengths=torch.tensor([3, 4, 1, 3]), ) sequence_dense_features = { - "historical_ts": JaggedTensor( + "uih_seq__historical_ts": JaggedTensor( values=torch.tensor([[1], [2], [3], [4], [5], [6], [7]]), lengths=torch.tensor([3, 4]), ), @@ -199,20 +221,30 @@ def test_hstu_match(self, graph_type, kernel, device_str) -> None: batch = _build_batch(device=device) if graph_type == TestGraphType.JIT_SCRIPT: - hstu.set_is_inference(True) - hstu = create_test_model(hstu, graph_type) - predictions = hstu(batch.to_dict(), device) + hstu_wrapped = create_test_model(hstu, graph_type) + predictions = hstu_wrapped(batch.to_dict(), device) elif graph_type == TestGraphType.FX_TRACE: - hstu = create_test_model(hstu, graph_type) - predictions = hstu(batch) + hstu_wrapped = create_test_model(hstu, graph_type) + predictions = hstu_wrapped(batch) else: - hstu = TrainWrapper(hstu, device=device).to(device) - _, (_, predictions, _) = hstu(batch) + hstu_wrapped = TrainWrapper(hstu, device=device).to(device) + _, (_, predictions, _) = hstu_wrapped(batch) self.assertIn("similarity", predictions) # Q = sum(pos_lengths) = 2; column count = 1 (pos) + neg count. self.assertEqual(predictions["similarity"].size(0), 2) + # Scalar-view contract: set_is_inference(True) flips item_tower + # to the scalar export view (bare sub-feature names). 
+ hstu.set_is_inference(True) + self.assertTrue(hstu.item_tower._is_inference) + scalar_features = hstu.item_tower.features + scalar_feature_groups = hstu.item_tower.feature_groups + self.assertEqual(scalar_features[0].name, "video_id") + self.assertFalse(scalar_features[0].is_grouped_sequence) + self.assertEqual(scalar_feature_groups[0].feature_names, ["video_id"]) + self.assertEqual(scalar_feature_groups[0].group_name, "candidate") + if __name__ == "__main__": unittest.main() diff --git a/tzrec/tests/match_integration_test.py b/tzrec/tests/match_integration_test.py index 8764b1d9..6c2ff407 100644 --- a/tzrec/tests/match_integration_test.py +++ b/tzrec/tests/match_integration_test.py @@ -378,9 +378,50 @@ def test_hstu_with_fg_train_eval(self): ) if self.success: self.success = utils.test_eval( - os.path.join(self.test_dir, "pipeline.config"), self.test_dir + os.path.join(self.test_dir, "pipeline.config"), + self.test_dir, + ) + if self.success: + self.success = utils.test_export( + os.path.join(self.test_dir, "pipeline.config"), + self.test_dir, + env_str="ENABLE_AOT=1", + item_input_path="data/test/kuairand-1k-match-item-c1.parquet", + ) + if self.success: + # Item tower scalar export view: predict over the item-only + # parquet (one row per video_id) and emit item embeddings. + self.success = utils.test_predict( + scripted_model_path=os.path.join(self.test_dir, "export/item"), + predict_input_path="data/test/kuairand-1k-match-item-c1.parquet", + predict_output_path=os.path.join(self.test_dir, "item_emb"), + reserved_columns="video_id", + output_columns="item_tower_emb", + test_dir=self.test_dir, + ) + if self.success: + # User tower keeps the training-shape sequence view: predict + # over the eval parquet (which carries the user-side columns). 
+ self.success = utils.test_predict( + scripted_model_path=os.path.join(self.test_dir, "export/user"), + predict_input_path=( + "data/test/kuairand-1k-match-eval-c4096-s100.parquet" + ), + predict_output_path=os.path.join(self.test_dir, "user_emb"), + reserved_columns="user_id", + output_columns="user_tower_emb", + test_dir=self.test_dir, ) self.assertTrue(self.success) + for side in ("user", "item"): + self.assertTrue( + os.path.exists( + os.path.join( + self.test_dir, f"export/{side}/scripted_sparse_model.pt" + ) + ), + f"missing AOT scripted sparse model for {side} tower", + ) if __name__ == "__main__": diff --git a/tzrec/tests/utils.py b/tzrec/tests/utils.py index 333a8a2f..ef2fd3dc 100644 --- a/tzrec/tests/utils.py +++ b/tzrec/tests/utils.py @@ -1040,6 +1040,7 @@ def test_export( asset_files: str = "", env_str: str = "", additional_export_config: str = "", + item_input_path: str = "", ) -> bool: """Run export integration test.""" log_dir = os.path.join(test_dir, "log_export") @@ -1056,7 +1057,9 @@ def test_export( if asset_files: cmd_str += f"--asset_files {asset_files} " if additional_export_config: - cmd_str += f"--additional_export_config '{additional_export_config}'" + cmd_str += f"--additional_export_config '{additional_export_config}' " + if item_input_path: + cmd_str += f"--item_input_path {item_input_path} " return misc_util.run_cmd( cmd_str, os.path.join(test_dir, "log_export.txt"), timeout=1800 diff --git a/tzrec/utils/export_util.py b/tzrec/utils/export_util.py index ee02ab8a..05642f7f 100644 --- a/tzrec/utils/export_util.py +++ b/tzrec/utils/export_util.py @@ -80,8 +80,13 @@ def export_model( save_dir: str, assets: Optional[List[str]] = None, additional_export_config: Optional[Dict[str, str]] = None, + data_input_path: Optional[str] = None, ) -> None: - """Export a EasyRec model, may be a part of model in PipelineConfig.""" + """Export a EasyRec model, may be a part of model in PipelineConfig. 
+ + `data_input_path` (optional): override for the predict-mode dataloader + input path; falls back to `pipeline_config.train_input_path` when None. + """ use_rtp = env_util.use_rtp() impl = export_rtp_model if use_rtp else export_model_normal @@ -100,6 +105,7 @@ def export_model( assets=assets, use_local_cache_dir=use_local_cache_dir, additional_export_config=additional_export_config, + data_input_path=data_input_path, ) if use_local_cache_dir and int(os.environ.get("LOCAL_RANK", 0)) == 0: logger.info(f"uploading {local_path} to {save_dir}.") @@ -142,6 +148,7 @@ def export_model_normal( save_dir: str, assets: Optional[List[str]] = None, additional_export_config: Optional[Dict[str, str]] = None, + data_input_path: Optional[str] = None, **kwargs: Any, ) -> None: """Export a EasyRec model on aliyun.""" @@ -168,9 +175,8 @@ def export_model_normal( data_config.batch_size = min(data_config.batch_size, max_batch_size) logger.info("using new batch_size: %s in export", data_config.batch_size) data_config.num_workers = 1 - dataloader = create_dataloader( - data_config, features, pipeline_config.train_input_path, mode=Mode.PREDICT - ) + input_path = data_input_path or pipeline_config.train_input_path + dataloader = create_dataloader(data_config, features, input_path, mode=Mode.PREDICT) ckpt_param_map_path = None if checkpoint_path: @@ -190,8 +196,6 @@ def export_model_normal( if is_rank_zero: if not os.path.exists(save_dir): os.makedirs(save_dir) - model.set_is_inference(True) - init_parameters(model, torch.device("cpu")) checkpoint_util.restore_model( checkpoint_path, model, ckpt_param_map_path=ckpt_param_map_path @@ -666,6 +670,7 @@ def export_rtp_model( save_dir: str, assets: Optional[List[str]] = None, use_local_cache_dir: bool = False, + data_input_path: Optional[str] = None, **kwargs: Any, ) -> None: """Export a EasyRec model on RTP.""" @@ -714,14 +719,11 @@ def _all_keys_used_once( features = cast(List[BaseFeature], model.features) data_config.num_workers = 1 
data_config.batch_size = acc_utils.get_max_export_batch_size() - dataloader = create_dataloader( - data_config, features, pipeline_config.train_input_path, mode=Mode.PREDICT - ) + input_path = data_input_path or pipeline_config.train_input_path + dataloader = create_dataloader(data_config, features, input_path, mode=Mode.PREDICT) batch = next(iter(dataloader)) data = batch.to(device).to_dict(sparse_dtype=torch.int64) - model.set_is_inference(True) - # Build Sharded Model planner = create_planner( device=device, @@ -1042,7 +1044,6 @@ def split_model( if not os.path.exists(graph_dir): os.makedirs(graph_dir) - model.set_is_inference(True) model.eval() tracer = Tracer()