Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions tzrec/acc/aot_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,8 +303,8 @@ def _build_dynamic_shapes(

Args:
data: input tensor dict from Batch.to_dict().
features: list of BaseFeature from model._features.
feature_groups: list of FeatureGroupConfig from model._feature_groups.
features: list of BaseFeature from model.features.
feature_groups: list of FeatureGroupConfig from model.feature_groups.

Returns:
dynamic_shapes dict for torch.export.export().
Expand Down Expand Up @@ -467,14 +467,14 @@ def export_unified_model_aot(

# Pad any 0-size non-sequence sparse .values tensors so torch.export
# doesn't specialize on the empty size (which conflicts with dynamic Dims).
seq_feat_names = {f.name for f in model._features if f.is_sequence}
seq_feat_names = {f.name for f in model.features if f.is_sequence}
data = _pad_empty_sparse_values(data, seq_feat_names)

# Build dynamic shapes using feature metadata for correct Dim grouping
dynamic_shapes = _build_dynamic_shapes(
data,
features=model._features,
feature_groups=model._feature_groups,
features=model.features,
feature_groups=model.feature_groups,
)
logger.info("dynamic shapes=%s" % dynamic_shapes)

Expand Down
46 changes: 41 additions & 5 deletions tzrec/models/match_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,16 @@ def __init__(
self.group_variational_dropouts = None
self.group_variational_dropout_loss = {}

@property
def features(self) -> List[BaseFeature]:
"""Tower's features (default property forwarding to ``self._features``)."""
return self._features

@property
def feature_groups(self) -> List[model_pb2.FeatureGroupConfig]:
"""Tower's feature_groups (default forward to ``self._feature_groups``)."""
return self._feature_groups

def init_input(self) -> None:
"""Build embedding group and group variational dropout."""
self.embedding_group = EmbeddingGroup(self._features, self._feature_groups)
Expand Down Expand Up @@ -222,6 +232,16 @@ def __init__(
self._feature_groups = feature_groups
self._features = features

@property
def features(self) -> List[BaseFeature]:
"""Tower's features (default property forwarding to ``self._features``)."""
return self._features

@property
def feature_groups(self) -> List[model_pb2.FeatureGroupConfig]:
"""Tower's feature_groups (default forward to ``self._feature_groups``)."""
return self._feature_groups


class MatchModel(BaseModel):
"""Base model for match.
Expand Down Expand Up @@ -457,10 +477,18 @@ class TowerWrapper(nn.Module):
def __init__(self, module: nn.Module, tower_name: str = "user_tower") -> None:
super().__init__()
setattr(self, tower_name, module)
self._features = module._features
self._feature_groups = module._feature_groups
self._tower_name = tower_name

@property
def features(self) -> List[BaseFeature]:
"""Live read of the wrapped tower's features (no snapshot)."""
return getattr(self, self._tower_name).features

@property
def feature_groups(self) -> List[model_pb2.FeatureGroupConfig]:
"""Live read of the wrapped tower's feature_groups."""
return getattr(self, self._tower_name).feature_groups

def predict(self, batch: Batch) -> Dict[str, torch.Tensor]:
"""Forward the tower.

Expand All @@ -478,13 +506,21 @@ class TowerWoEGWrapper(nn.Module):

def __init__(self, module: nn.Module, tower_name: str = "user_tower") -> None:
super().__init__()
self.embedding_group = EmbeddingGroup(module._features, module._feature_groups)
self.embedding_group = EmbeddingGroup(module.features, module.feature_groups)
setattr(self, tower_name, module)
self._features = module._features
self._feature_groups = module._feature_groups
self._tower_name = tower_name
self._group_name = module._group_name
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Inconsistent with the rest of the migration: this wrapper now reads features / feature_groups through the property surface but still reaches into module._group_name directly. The same "let subclasses override per-view" rationale applies to group_name — a MatchTowerWoEG subclass that wanted to expose a different group_name at export would have no seam.

Consider either (a) adding a group_name property on MatchTower / MatchTowerWoEG and reading through it here, or (b) at minimum a one-line comment noting _group_name is deliberately kept as a private contract between the tower and its wrapper. Symmetry with the rest of the PR would argue for (a).


@property
def features(self) -> List[BaseFeature]:
"""Live read of the wrapped tower's features (no snapshot)."""
return getattr(self, self._tower_name).features

@property
def feature_groups(self) -> List[model_pb2.FeatureGroupConfig]:
"""Live read of the wrapped tower's feature_groups."""
return getattr(self, self._tower_name).feature_groups

def predict(self, batch: Batch) -> Dict[str, torch.Tensor]:
"""Forward the tower.

Expand Down
24 changes: 21 additions & 3 deletions tzrec/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,16 @@ def __init__(

self._train_metric_modules = nn.ModuleDict()

@property
def features(self) -> List[BaseFeature]:
"""Model's features (default property forwarding to ``self._features``)."""
return self._features

@property
def feature_groups(self) -> List[FeatureGroupConfig]:
"""Model's feature_groups (default forward to ``self._feature_groups``)."""
return self._feature_groups

def predict(self, batch: Batch) -> Dict[str, torch.Tensor]:
"""Predict the model.

Expand Down Expand Up @@ -336,15 +346,23 @@ class ScriptWrapper(BaseModule):
def __init__(self, module: nn.Module) -> None:
super().__init__()
self.model = module
self._features = self.model._features
self._feature_groups = self.model._feature_groups
self._data_parser = DataParser(
self._features,
self.model.features,
sampler_type=str(module.sampler_type)
if hasattr(module, "sampler_type")
else None,
)

@property
def features(self) -> List[BaseFeature]:
"""Live read of the wrapped module's features (no snapshot)."""
return self.model.features

@property
def feature_groups(self) -> List[FeatureGroupConfig]:
"""Live read of the wrapped module's feature_groups."""
return self.model.feature_groups

def get_batch(
self,
data: Dict[str, torch.Tensor],
Expand Down
10 changes: 10 additions & 0 deletions tzrec/models/tdm.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,16 @@ def __init__(
EmbeddingGroup(seq_group_query_fea, self._feature_groups),
)

@property
def features(self) -> List[BaseFeature]:
"""Query-side features consumed by the TDM embedding."""
return self._features

@property
def feature_groups(self) -> List[model_pb2.FeatureGroupConfig]:
"""Single feature_group consumed by the TDM embedding."""
return self._feature_groups

def predict(self, batch: Batch) -> Dict[str, torch.Tensor]:
"""Forward the embedding.

Expand Down
8 changes: 4 additions & 4 deletions tzrec/utils/export_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ def export_model_normal(

# make dataparser to get user feats before create model
data_config = copy.deepcopy(pipeline_config.data_config)
features = cast(List[BaseFeature], model._features)
features = cast(List[BaseFeature], model.features)
if acc_utils.is_cuda_export():
# export batch_size too large may OOM in compile phase
max_batch_size = acc_utils.get_max_export_batch_size()
Expand Down Expand Up @@ -711,7 +711,7 @@ def _all_keys_used_once(

# make dataparser to get user feats before create model
data_config = copy.deepcopy(pipeline_config.data_config)
features = cast(List[BaseFeature], model._features)
features = cast(List[BaseFeature], model.features)
data_config.num_workers = 1
data_config.batch_size = acc_utils.get_max_export_batch_size()
dataloader = create_dataloader(
Expand Down Expand Up @@ -1160,8 +1160,8 @@ def _seq_feat_name(seq_name: str) -> str:
dense_gm = _prune_unused_param_and_buffer(dense_gm)

seq_share_groups = _compute_seq_share_groups(
features=cast(List[BaseFeature], model._features),
feature_groups=model._feature_groups,
features=cast(List[BaseFeature], model.features),
feature_groups=model.feature_groups,
)
meta_info = {
"seq_tensor_names": seq_tensor_names,
Expand Down
2 changes: 1 addition & 1 deletion tzrec/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.

__version__ = "1.2.11"
__version__ = "1.2.12"
Loading