From 860dea142e1a175c5e969e9b1b9d3823f290f321 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A4=A9=E9=82=91?= Date: Wed, 20 May 2026 19:39:58 +0800 Subject: [PATCH 1/4] [refactor] expose features/feature_groups as @property on model base classes Add `features` / `feature_groups` `@property` accessors to the four classes that own the underlying `_features` / `_feature_groups` fields. Each property is a default read of the corresponding underscore field -- no functional change today, but it gives subclasses a seam to override the surface per view (training vs export) without mutating the underscore fields. Touched: * `BaseModel` (tzrec/models/model.py): covers every ranking / matching / multi-task model and its `MatchModel` subclass. * `MatchTower` (tzrec/models/match_model.py): the base for `DSSMTower`, `DATTower`, `MINDUserTower`, `MINDItemTower`. Inherits from `BaseModule`, which does not own the underscore fields, so the property must be defined here for `TowerWrapper`'s property-based read (next commits) to succeed. * `MatchTowerWoEG` (tzrec/models/match_model.py): the base for `HSTUUserTower` / `HSTUMatchItemTower`. Inherits from `nn.Module`, so the properties must be defined here for the same reason. * `TDMEmbedding` (tzrec/models/tdm.py): an `nn.Module` exported separately by the TDM pipeline; owns its own `_features` / `_feature_groups` from the parent `EmbeddingGroup`. Pure additions -- no caller is rewired yet, so the existing underscore-field reads on these classes and their wrappers continue to work. 
Co-Authored-By: Claude Opus 4.7 (1M context) --- tzrec/models/match_model.py | 20 ++++++++++++++++++++ tzrec/models/model.py | 10 ++++++++++ tzrec/models/tdm.py | 10 ++++++++++ 3 files changed, 40 insertions(+) diff --git a/tzrec/models/match_model.py b/tzrec/models/match_model.py index 433cd69a..49953f32 100644 --- a/tzrec/models/match_model.py +++ b/tzrec/models/match_model.py @@ -144,6 +144,16 @@ def __init__( self.group_variational_dropouts = None self.group_variational_dropout_loss = {} + @property + def features(self) -> List[BaseFeature]: + """Tower's features (default property forwarding to ``self._features``).""" + return self._features + + @property + def feature_groups(self) -> List[model_pb2.FeatureGroupConfig]: + """Tower's feature_groups (default forward to ``self._feature_groups``).""" + return self._feature_groups + def init_input(self) -> None: """Build embedding group and group variational dropout.""" self.embedding_group = EmbeddingGroup(self._features, self._feature_groups) @@ -222,6 +232,16 @@ def __init__( self._feature_groups = feature_groups self._features = features + @property + def features(self) -> List[BaseFeature]: + """Tower's features (default property forwarding to ``self._features``).""" + return self._features + + @property + def feature_groups(self) -> List[model_pb2.FeatureGroupConfig]: + """Tower's feature_groups (default forward to ``self._feature_groups``).""" + return self._feature_groups + class MatchModel(BaseModel): """Base model for match. 
diff --git a/tzrec/models/model.py b/tzrec/models/model.py index 92ce5485..208e39c5 100644 --- a/tzrec/models/model.py +++ b/tzrec/models/model.py @@ -73,6 +73,16 @@ def __init__( self._train_metric_modules = nn.ModuleDict() + @property + def features(self) -> List[BaseFeature]: + """Model's features (default property forwarding to ``self._features``).""" + return self._features + + @property + def feature_groups(self) -> List[FeatureGroupConfig]: + """Model's feature_groups (default forward to ``self._feature_groups``).""" + return self._feature_groups + def predict(self, batch: Batch) -> Dict[str, torch.Tensor]: """Predict the model. diff --git a/tzrec/models/tdm.py b/tzrec/models/tdm.py index b706160c..0be8da1a 100644 --- a/tzrec/models/tdm.py +++ b/tzrec/models/tdm.py @@ -132,6 +132,16 @@ def __init__( EmbeddingGroup(seq_group_query_fea, self._feature_groups), ) + @property + def features(self) -> List[BaseFeature]: + """Query-side features consumed by the TDM embedding.""" + return self._features + + @property + def feature_groups(self) -> List[model_pb2.FeatureGroupConfig]: + """Single feature_group consumed by the TDM embedding.""" + return self._feature_groups + def predict(self, batch: Batch) -> Dict[str, torch.Tensor]: """Forward the embedding. From 7a02f14be753396642f62df5ff7f52eebd3909c4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A4=A9=E9=82=91?= Date: Wed, 20 May 2026 19:41:00 +0800 Subject: [PATCH 2/4] [refactor] migrate export_util / aot_utils to features/feature_groups property API MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Switch the five call sites that read `model._features` / `model._feature_groups` (where `model` is a wrapper: `ScriptWrapper(...)` for non-match, `InferWrapper(TowerWrapper(...))` or `InferWrapper(TowerWoEGWrapper(...))` for match) to the property API introduced in the previous commit. 
Sites migrated: * `tzrec/utils/export_util.py::export_model_normal` — `features` argument to `create_dataloader` (line 164). * `tzrec/utils/export_util.py::export_model_aot` setup — second dataloader site (line 714). * `tzrec/utils/export_util.py::split_model` — `_compute_seq_share_groups` call site reads both `features` and `feature_groups` (lines 1163-1164). * `tzrec/acc/aot_utils.py::export_unified_model_aot` — sequence feature name set (line 470) and `_build_dynamic_shapes` call (lines 476-477). * `tzrec/acc/aot_utils.py::_build_dynamic_shapes` docstring — `model._features` / `model._feature_groups` references updated to `model.features` / `model.feature_groups` to match the new API. Today the wrappers still expose the underscore-field snapshot, so both forms work; this commit threads the property reads through the export pipeline so the next commit (which drops the underscore-field snapshots from the wrappers) is a no-op for these call sites. Co-Authored-By: Claude Opus 4.7 (1M context) --- tzrec/acc/aot_utils.py | 10 +++++----- tzrec/utils/export_util.py | 8 ++++---- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/tzrec/acc/aot_utils.py b/tzrec/acc/aot_utils.py index 69e1b06d..74f08404 100644 --- a/tzrec/acc/aot_utils.py +++ b/tzrec/acc/aot_utils.py @@ -303,8 +303,8 @@ def _build_dynamic_shapes( Args: data: input tensor dict from Batch.to_dict(). - features: list of BaseFeature from model._features. - feature_groups: list of FeatureGroupConfig from model._feature_groups. + features: list of BaseFeature from model.features. + feature_groups: list of FeatureGroupConfig from model.feature_groups. Returns: dynamic_shapes dict for torch.export.export(). @@ -467,14 +467,14 @@ def export_unified_model_aot( # Pad any 0-size non-sequence sparse .values tensors so torch.export # doesn't specialize on the empty size (which conflicts with dynamic Dims). 
- seq_feat_names = {f.name for f in model._features if f.is_sequence} + seq_feat_names = {f.name for f in model.features if f.is_sequence} data = _pad_empty_sparse_values(data, seq_feat_names) # Build dynamic shapes using feature metadata for correct Dim grouping dynamic_shapes = _build_dynamic_shapes( data, - features=model._features, - feature_groups=model._feature_groups, + features=model.features, + feature_groups=model.feature_groups, ) logger.info("dynamic shapes=%s" % dynamic_shapes) diff --git a/tzrec/utils/export_util.py b/tzrec/utils/export_util.py index be786f9a..ee02ab8a 100644 --- a/tzrec/utils/export_util.py +++ b/tzrec/utils/export_util.py @@ -161,7 +161,7 @@ def export_model_normal( # make dataparser to get user feats before create model data_config = copy.deepcopy(pipeline_config.data_config) - features = cast(List[BaseFeature], model._features) + features = cast(List[BaseFeature], model.features) if acc_utils.is_cuda_export(): # export batch_size too large may OOM in compile phase max_batch_size = acc_utils.get_max_export_batch_size() @@ -711,7 +711,7 @@ def _all_keys_used_once( # make dataparser to get user feats before create model data_config = copy.deepcopy(pipeline_config.data_config) - features = cast(List[BaseFeature], model._features) + features = cast(List[BaseFeature], model.features) data_config.num_workers = 1 data_config.batch_size = acc_utils.get_max_export_batch_size() dataloader = create_dataloader( @@ -1160,8 +1160,8 @@ def _seq_feat_name(seq_name: str) -> str: dense_gm = _prune_unused_param_and_buffer(dense_gm) seq_share_groups = _compute_seq_share_groups( - features=cast(List[BaseFeature], model._features), - feature_groups=model._feature_groups, + features=cast(List[BaseFeature], model.features), + feature_groups=model.feature_groups, ) meta_info = { "seq_tensor_names": seq_tensor_names, From 355edca3b7ab2ee0a9ee8954a9726ebbd47d10eb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A4=A9=E9=82=91?= Date: Wed, 20 May 2026 19:42:04 
+0800 Subject: [PATCH 3/4] [refactor] drop _features/_feature_groups snapshots in wrappers; expose properties `ScriptWrapper`, `TowerWrapper`, and `TowerWoEGWrapper` previously mirrored their inner module's `_features` / `_feature_groups` onto themselves at construction time. With every wrapped class now exposing those metadata fields via `@property` (previous commits), the wrappers can drop the snapshot and forward live reads instead. Concretely: * `ScriptWrapper.__init__`: no longer copies `model._features` / `model._feature_groups`. `DataParser` is constructed from `self.model.features`. The class exposes `features` / `feature_groups` as `@property` reads of `self.model.features` / `self.model.feature_groups`. * `TowerWrapper.__init__` / `TowerWoEGWrapper.__init__`: same treatment, reading the wrapped tower via `getattr(self, self._tower_name)`. * `TowerWoEGWrapper.__init__` additionally switches its `EmbeddingGroup` build from `module._features, module._feature_groups` to `module.features, module.feature_groups`, so the construction-time snapshot that EmbeddingGroup needs (it owns nn.Parameters) goes through the same property surface as everything else. The downstream call sites in `tzrec/utils/export_util.py` and `tzrec/acc/aot_utils.py` were migrated to the property API in the previous commit, so this is a no-op for them. No other consumer reads the wrappers' underscore fields directly. No functional change for DSSM/MIND/DAT/TDM/HSTU: the properties return the same `_features` / `_feature_groups` lists the wrappers used to snapshot, identical to the pre-refactor behaviour. 
Co-Authored-By: Claude Opus 4.7 (1M context) --- tzrec/models/match_model.py | 26 +++++++++++++++++++++----- tzrec/models/model.py | 14 +++++++++++--- 2 files changed, 32 insertions(+), 8 deletions(-) diff --git a/tzrec/models/match_model.py b/tzrec/models/match_model.py index 49953f32..1678f51c 100644 --- a/tzrec/models/match_model.py +++ b/tzrec/models/match_model.py @@ -477,10 +477,18 @@ class TowerWrapper(nn.Module): def __init__(self, module: nn.Module, tower_name: str = "user_tower") -> None: super().__init__() setattr(self, tower_name, module) - self._features = module._features - self._feature_groups = module._feature_groups self._tower_name = tower_name + @property + def features(self) -> List[BaseFeature]: + """Live read of the wrapped tower's features (no snapshot).""" + return getattr(self, self._tower_name).features + + @property + def feature_groups(self) -> List[model_pb2.FeatureGroupConfig]: + """Live read of the wrapped tower's feature_groups.""" + return getattr(self, self._tower_name).feature_groups + def predict(self, batch: Batch) -> Dict[str, torch.Tensor]: """Forward the tower. 
@@ -498,13 +506,21 @@ class TowerWoEGWrapper(nn.Module): def __init__(self, module: nn.Module, tower_name: str = "user_tower") -> None: super().__init__() - self.embedding_group = EmbeddingGroup(module._features, module._feature_groups) + self.embedding_group = EmbeddingGroup(module.features, module.feature_groups) setattr(self, tower_name, module) - self._features = module._features - self._feature_groups = module._feature_groups self._tower_name = tower_name self._group_name = module._group_name + @property + def features(self) -> List[BaseFeature]: + """Live read of the wrapped tower's features (no snapshot).""" + return getattr(self, self._tower_name).features + + @property + def feature_groups(self) -> List[model_pb2.FeatureGroupConfig]: + """Live read of the wrapped tower's feature_groups.""" + return getattr(self, self._tower_name).feature_groups + def predict(self, batch: Batch) -> Dict[str, torch.Tensor]: """Forward the tower. diff --git a/tzrec/models/model.py b/tzrec/models/model.py index 208e39c5..40da5335 100644 --- a/tzrec/models/model.py +++ b/tzrec/models/model.py @@ -346,15 +346,23 @@ class ScriptWrapper(BaseModule): def __init__(self, module: nn.Module) -> None: super().__init__() self.model = module - self._features = self.model._features - self._feature_groups = self.model._feature_groups self._data_parser = DataParser( - self._features, + self.model.features, sampler_type=str(module.sampler_type) if hasattr(module, "sampler_type") else None, ) + @property + def features(self) -> List[BaseFeature]: + """Live read of the wrapped module's features (no snapshot).""" + return self.model.features + + @property + def feature_groups(self) -> List[FeatureGroupConfig]: + """Live read of the wrapped module's feature_groups.""" + return self.model.feature_groups + def get_batch( self, data: Dict[str, torch.Tensor], From 143cae5171fbbdfd76edc29657eeec3886d98ab6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A4=A9=E9=82=91?= Date: Wed, 20 May 2026 20:20:15 
+0800 Subject: [PATCH 4/4] [chore] bump version to 1.2.12 Co-Authored-By: Claude Opus 4.7 (1M context) --- tzrec/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tzrec/version.py b/tzrec/version.py index 27ce2b2b..1c5d53f1 100644 --- a/tzrec/version.py +++ b/tzrec/version.py @@ -9,4 +9,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "1.2.11" +__version__ = "1.2.12"