diff --git a/tzrec/acc/aot_utils.py b/tzrec/acc/aot_utils.py index 69e1b06d..74f08404 100644 --- a/tzrec/acc/aot_utils.py +++ b/tzrec/acc/aot_utils.py @@ -303,8 +303,8 @@ def _build_dynamic_shapes( Args: data: input tensor dict from Batch.to_dict(). - features: list of BaseFeature from model._features. - feature_groups: list of FeatureGroupConfig from model._feature_groups. + features: list of BaseFeature from model.features. + feature_groups: list of FeatureGroupConfig from model.feature_groups. Returns: dynamic_shapes dict for torch.export.export(). @@ -467,14 +467,14 @@ def export_unified_model_aot( # Pad any 0-size non-sequence sparse .values tensors so torch.export # doesn't specialize on the empty size (which conflicts with dynamic Dims). - seq_feat_names = {f.name for f in model._features if f.is_sequence} + seq_feat_names = {f.name for f in model.features if f.is_sequence} data = _pad_empty_sparse_values(data, seq_feat_names) # Build dynamic shapes using feature metadata for correct Dim grouping dynamic_shapes = _build_dynamic_shapes( data, - features=model._features, - feature_groups=model._feature_groups, + features=model.features, + feature_groups=model.feature_groups, ) logger.info("dynamic shapes=%s" % dynamic_shapes) diff --git a/tzrec/models/match_model.py b/tzrec/models/match_model.py index 433cd69a..1678f51c 100644 --- a/tzrec/models/match_model.py +++ b/tzrec/models/match_model.py @@ -144,6 +144,16 @@ def __init__( self.group_variational_dropouts = None self.group_variational_dropout_loss = {} + @property + def features(self) -> List[BaseFeature]: + """Tower's features (default property forwarding to ``self._features``).""" + return self._features + + @property + def feature_groups(self) -> List[model_pb2.FeatureGroupConfig]: + """Tower's feature_groups (default forward to ``self._feature_groups``).""" + return self._feature_groups + def init_input(self) -> None: """Build embedding group and group variational dropout.""" self.embedding_group = EmbeddingGroup(self._features, self._feature_groups) @@ -222,6 +232,16 @@ def __init__( self._feature_groups = feature_groups self._features = features + @property + def features(self) -> List[BaseFeature]: + """Tower's features (default property forwarding to ``self._features``).""" + return self._features + + @property + def feature_groups(self) -> List[model_pb2.FeatureGroupConfig]: + """Tower's feature_groups (default forward to ``self._feature_groups``).""" + return self._feature_groups + class MatchModel(BaseModel): """Base model for match. @@ -457,10 +477,18 @@ class TowerWrapper(nn.Module): def __init__(self, module: nn.Module, tower_name: str = "user_tower") -> None: super().__init__() setattr(self, tower_name, module) - self._features = module._features - self._feature_groups = module._feature_groups self._tower_name = tower_name + @property + def features(self) -> List[BaseFeature]: + """Live read of the wrapped tower's features (no snapshot).""" + return getattr(self, self._tower_name).features + + @property + def feature_groups(self) -> List[model_pb2.FeatureGroupConfig]: + """Live read of the wrapped tower's feature_groups.""" + return getattr(self, self._tower_name).feature_groups + def predict(self, batch: Batch) -> Dict[str, torch.Tensor]: """Forward the tower. @@ -478,13 +506,21 @@ class TowerWoEGWrapper(nn.Module): def __init__(self, module: nn.Module, tower_name: str = "user_tower") -> None: super().__init__() - self.embedding_group = EmbeddingGroup(module._features, module._feature_groups) + self.embedding_group = EmbeddingGroup(module.features, module.feature_groups) setattr(self, tower_name, module) - self._features = module._features - self._feature_groups = module._feature_groups self._tower_name = tower_name self._group_name = module._group_name + @property + def features(self) -> List[BaseFeature]: + """Live read of the wrapped tower's features (no snapshot).""" + return getattr(self, self._tower_name).features + + @property + def feature_groups(self) -> List[model_pb2.FeatureGroupConfig]: + """Live read of the wrapped tower's feature_groups.""" + return getattr(self, self._tower_name).feature_groups + def predict(self, batch: Batch) -> Dict[str, torch.Tensor]: """Forward the tower. diff --git a/tzrec/models/model.py b/tzrec/models/model.py index 92ce5485..40da5335 100644 --- a/tzrec/models/model.py +++ b/tzrec/models/model.py @@ -73,6 +73,16 @@ def __init__( self._train_metric_modules = nn.ModuleDict() + @property + def features(self) -> List[BaseFeature]: + """Model's features (default property forwarding to ``self._features``).""" + return self._features + + @property + def feature_groups(self) -> List[FeatureGroupConfig]: + """Model's feature_groups (default forward to ``self._feature_groups``).""" + return self._feature_groups + def predict(self, batch: Batch) -> Dict[str, torch.Tensor]: """Predict the model. @@ -336,15 +346,23 @@ class ScriptWrapper(BaseModule): def __init__(self, module: nn.Module) -> None: super().__init__() self.model = module - self._features = self.model._features - self._feature_groups = self.model._feature_groups self._data_parser = DataParser( - self._features, + self.model.features, sampler_type=str(module.sampler_type) if hasattr(module, "sampler_type") else None, ) + @property + def features(self) -> List[BaseFeature]: + """Live read of the wrapped module's features (no snapshot).""" + return self.model.features + + @property + def feature_groups(self) -> List[FeatureGroupConfig]: + """Live read of the wrapped module's feature_groups.""" + return self.model.feature_groups + def get_batch( self, data: Dict[str, torch.Tensor], diff --git a/tzrec/models/tdm.py b/tzrec/models/tdm.py index b706160c..0be8da1a 100644 --- a/tzrec/models/tdm.py +++ b/tzrec/models/tdm.py @@ -132,6 +132,16 @@ def __init__( EmbeddingGroup(seq_group_query_fea, self._feature_groups), ) + @property + def features(self) -> List[BaseFeature]: + """Query-side features consumed by the TDM embedding.""" + return self._features + + @property + def feature_groups(self) -> List[model_pb2.FeatureGroupConfig]: + """Single feature_group consumed by the TDM embedding.""" + return self._feature_groups + def predict(self, batch: Batch) -> Dict[str, torch.Tensor]: """Forward the embedding. diff --git a/tzrec/utils/export_util.py b/tzrec/utils/export_util.py index be786f9a..ee02ab8a 100644 --- a/tzrec/utils/export_util.py +++ b/tzrec/utils/export_util.py @@ -161,7 +161,7 @@ def export_model_normal( # make dataparser to get user feats before create model data_config = copy.deepcopy(pipeline_config.data_config) - features = cast(List[BaseFeature], model._features) + features = cast(List[BaseFeature], model.features) if acc_utils.is_cuda_export(): # export batch_size too large may OOM in compile phase max_batch_size = acc_utils.get_max_export_batch_size() @@ -711,7 +711,7 @@ def _all_keys_used_once( # make dataparser to get user feats before create model data_config = copy.deepcopy(pipeline_config.data_config) - features = cast(List[BaseFeature], model._features) + features = cast(List[BaseFeature], model.features) data_config.num_workers = 1 data_config.batch_size = acc_utils.get_max_export_batch_size() dataloader = create_dataloader( @@ -1160,8 +1160,8 @@ def _seq_feat_name(seq_name: str) -> str: dense_gm = _prune_unused_param_and_buffer(dense_gm) seq_share_groups = _compute_seq_share_groups( - features=cast(List[BaseFeature], model._features), - feature_groups=model._feature_groups, + features=cast(List[BaseFeature], model.features), + feature_groups=model.feature_groups, ) meta_info = { "seq_tensor_names": seq_tensor_names, diff --git a/tzrec/version.py b/tzrec/version.py index 27ce2b2b..1c5d53f1 100644 --- a/tzrec/version.py +++ b/tzrec/version.py @@ -9,4 +9,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "1.2.11" +__version__ = "1.2.12"