-
Notifications
You must be signed in to change notification settings - Fork 72
[feat] HSTUMatch: scalar item-tower export view #518
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
1a8b7d1
aa76da2
7a9318f
0c1dcf1
2da7c29
4578c75
8cd5d3e
35b18cb
11ef729
e7836ff
1e885b5
59af910
f4925dc
ae17668
9f11f05
fc50ddd
d25eb5b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -895,6 +895,7 @@ def export( | |
| checkpoint_path: Optional[str] = None, | ||
| asset_files: Optional[str] = None, | ||
| additional_export_config: Optional[Dict[str, str]] = None, | ||
| item_input_path: Optional[str] = None, | ||
| ) -> None: | ||
| """Export a EasyRec model. | ||
|
|
||
|
|
@@ -906,6 +907,9 @@ def export( | |
| asset_files (str, optional): more files will be copied to export_dir. | ||
| additional_export_config (dict, optional): extra key/value pairs merged | ||
| into model_acc.json (e.g. ``{"cand_seq_pk": "cand_seq"}`` for DlrmHSTU). | ||
| item_input_path (str, optional): override for the item tower's | ||
| predict-mode dataloader input path. When set, the item tower | ||
| reads from this path instead of ``train_input_path``. | ||
| """ | ||
| is_rank_zero = int(os.environ.get("RANK", 0)) == 0 | ||
|
|
||
|
|
@@ -936,6 +940,10 @@ def export( | |
| sampler_type=None, | ||
| ) | ||
| InferWrapper = ScriptWrapper | ||
| # Flip to inference *before* wrapping so view-dependent state | ||
| # (e.g. HSTUMatchItemTower's lazy properties, wrapper EmbeddingGroups) | ||
| # is snapshot from the scalar view. | ||
| model.set_is_inference(True) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Worth a short comment here naming the reason this must flip on all ranks pre-wrap: Also, asymmetry to watch for: the outer |
||
| model = InferWrapper(model) | ||
|
|
||
| if not checkpoint_path: | ||
|
|
@@ -959,13 +967,18 @@ def export( | |
| ) | ||
| tower = InferWrapper(wrapper(module, name)) | ||
| tower_export_dir = os.path.join(export_dir, name.replace("_tower", "")) | ||
| # item-tower-only; user tower falls back to `train_input_path`. | ||
| tower_data_input_path = ( | ||
| item_input_path if name == "item_tower" else None | ||
| ) | ||
|
Comment on lines
+970
to
+973
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The user-tower fallback to If you want to lock this in cheaply, a small unit test in |
||
| export_model( | ||
| ori_pipeline_config, | ||
| tower, | ||
| checkpoint_path, | ||
| tower_export_dir, | ||
| assets=assets, | ||
| additional_export_config=additional_export_config, | ||
| data_input_path=tower_data_input_path, | ||
| ) | ||
| elif isinstance(model.model, TDM): | ||
| for name, module in model.model.named_children(): | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -15,8 +15,17 @@ | |
| import torch.nn.functional as F | ||
| from torch import nn | ||
|
|
||
| from tzrec.datasets.utils import CAND_POS_LENGTHS, HARD_NEG_INDICES, Batch | ||
| from tzrec.features.feature import BaseFeature | ||
| from tzrec.datasets.utils import ( | ||
| BASE_DATA_GROUP, | ||
| CAND_POS_LENGTHS, | ||
| HARD_NEG_INDICES, | ||
| Batch, | ||
| ) | ||
| from tzrec.features.feature import ( | ||
| BaseFeature, | ||
| create_features, | ||
| project_grouped_sequence_feature_to_scalar, | ||
| ) | ||
| from tzrec.models.match_model import MatchModel, MatchTowerWoEG | ||
| from tzrec.modules.embedding import EmbeddingGroup | ||
| from tzrec.modules.gr.hstu_transducer import HSTUMatchEncoder | ||
|
|
@@ -151,8 +160,18 @@ def __init__( | |
| # tower_config.input names on the user-tower proto). Use the item-side | ||
| # tower_config.input here, which equals feature_groups[0].group_name. | ||
| self._group_name = tower_config.input | ||
| # MLP sized off the training candidate group; the scalar view has | ||
| # identical per-feature embedding dim. | ||
| candidate_dims = embedding_group.group_dims(f"{self._group_name}.sequence") | ||
| candidate_total_dim = sum(candidate_dims) | ||
|
|
||
| # Lazy caches for the scalar export view (populated on first | ||
| # property access after `set_is_inference(True)`). | ||
| self._features_scalar: Optional[List[BaseFeature]] = None | ||
| self._feature_groups_scalar: Optional[List[model_pb2.FeatureGroupConfig]] = None | ||
|
Comment on lines
+168
to
+171
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Worth documenting the cache-validity invariant: |
||
| # `MatchTowerWoEG` derives from `nn.Module`, not `BaseModule`, | ||
| # so init `_is_inference` here. | ||
| self._is_inference: bool = False | ||
| if tower_config.HasField("mlp"): | ||
| self.mlp: torch.nn.Module = MLP( | ||
| in_features=candidate_total_dim, | ||
|
|
@@ -166,6 +185,48 @@ def __init__( | |
| if self._output_dim > 0: | ||
| self.output = nn.Linear(mlp_out_dim, output_dim) | ||
|
|
||
| @property | ||
| def features(self) -> List[BaseFeature]: | ||
| """Item features (training: grouped sub-features; export: scalar projection).""" | ||
| if self._is_inference: | ||
| if self._features_scalar is None: | ||
| self._build_scalar_features() | ||
| return self._features_scalar | ||
| return self._features | ||
|
|
||
| @property | ||
| def feature_groups(self) -> List[model_pb2.FeatureGroupConfig]: | ||
| """Item feature_groups in the current view (see ``features``).""" | ||
| if self._is_inference: | ||
| if self._feature_groups_scalar is None: | ||
| self._build_scalar_features() | ||
| return self._feature_groups_scalar | ||
| return self._feature_groups | ||
|
|
||
| def _build_scalar_features(self) -> None: | ||
| """Project each grouped sequence sub-feature into a scalar export feature.""" | ||
| scalar_configs = [ | ||
| project_grouped_sequence_feature_to_scalar(f) for f in self._features | ||
| ] | ||
| source = self._features[0] | ||
| scalar_features = create_features( | ||
| scalar_configs, | ||
| fg_mode=source.fg_mode, | ||
| neg_fields=None, | ||
| fg_encoded_multival_sep=source._fg_encoded_multival_sep, | ||
| force_base_data_group=any( | ||
| f.data_group == BASE_DATA_GROUP for f in self._features | ||
| ), | ||
| ) | ||
| self._features_scalar = scalar_features | ||
| self._feature_groups_scalar = [ | ||
| model_pb2.FeatureGroupConfig( | ||
| group_name=self._group_name, | ||
| feature_names=[f.name for f in scalar_features], | ||
| group_type=model_pb2.JAGGED_SEQUENCE, | ||
| ) | ||
| ] | ||
|
|
||
| def forward(self, grouped_features: Dict[str, torch.Tensor]) -> torch.Tensor: | ||
| """Forward the item tower. | ||
|
|
||
|
|
@@ -175,7 +236,9 @@ def forward(self, grouped_features: Dict[str, torch.Tensor]) -> torch.Tensor: | |
| Returns: | ||
| item embeddings of shape (sum_candidates, D). | ||
| """ | ||
| cand_emb = grouped_features[f"{self._group_name}.sequence"] | ||
| # `.sequence` (jagged) at training, `.query` (scalar) at export. | ||
| suffix = ".query" if self._is_inference else ".sequence" | ||
| cand_emb = grouped_features[self._group_name + suffix] | ||
| item_emb = self.mlp(cand_emb) | ||
| if self._output_dim > 0: | ||
| item_emb = self.output(item_emb) | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Two things:
hasattrguards are meaningful only forOverlapFeature(itsdefault_valueis commented out in the proto); for the other 11SeqFeatureConfig.featurevariants both fields are always present, so the guard reads as defensive code with no purpose. Worth namingOverlapFeaturein a comment.if not dst_msg.default_value/not dst_msg.HasField("value_dim")) are not exercised. The test infeature_test.pyonly covers the "both unset → materialize from source" path. Consider adding a case where the inner proto already sets e.g.default_value="-1"orvalue_dim=4and asserting the materialization does not overwrite. Same forvalue_dim > 1onraw_feature— currently onlyvalue_dim == 1(default) is asserted.