1515import torch .nn .functional as F
1616from torch import nn
1717
18- from tzrec .datasets .utils import CAND_POS_LENGTHS , HARD_NEG_INDICES , Batch
19- from tzrec .features .feature import BaseFeature
18+ from tzrec .datasets .utils import (
19+ BASE_DATA_GROUP ,
20+ CAND_POS_LENGTHS ,
21+ HARD_NEG_INDICES ,
22+ Batch ,
23+ )
24+ from tzrec .features .feature import (
25+ BaseFeature ,
26+ create_features ,
27+ project_grouped_sequence_feature_to_scalar ,
28+ )
2029from tzrec .models .match_model import MatchModel , MatchTowerWoEG
2130from tzrec .modules .embedding import EmbeddingGroup
2231from tzrec .modules .gr .hstu_transducer import HSTUMatchEncoder
@@ -151,8 +160,18 @@ def __init__(
151160 # tower_config.input names on the user-tower proto). Use the item-side
152161 # tower_config.input here, which equals feature_groups[0].group_name.
153162 self ._group_name = tower_config .input
163+ # MLP sized off the training candidate group; the scalar view has
164+ # identical per-feature embedding dim.
154165 candidate_dims = embedding_group .group_dims (f"{ self ._group_name } .sequence" )
155166 candidate_total_dim = sum (candidate_dims )
167+
168+ # Lazy caches for the scalar export view (populated on first
169+ # property access after `set_is_inference(True)`).
170+ self ._features_scalar : Optional [List [BaseFeature ]] = None
171+ self ._feature_groups_scalar : Optional [List [model_pb2 .FeatureGroupConfig ]] = None
172+ # `MatchTowerWoEG` derives from `nn.Module`, not `BaseModule`,
173+ # so init `_is_inference` here.
174+ self ._is_inference : bool = False
156175 if tower_config .HasField ("mlp" ):
157176 self .mlp : torch .nn .Module = MLP (
158177 in_features = candidate_total_dim ,
@@ -166,6 +185,48 @@ def __init__(
166185 if self ._output_dim > 0 :
167186 self .output = nn .Linear (mlp_out_dim , output_dim )
168187
@property
def features(self) -> List[BaseFeature]:
    """Return the item features for the currently active view.

    Outside inference mode this is the stored ``_features`` list. In
    inference mode it is the cached scalar list, which is built through
    ``_build_scalar_features`` the first time it is requested.
    """
    if not self._is_inference:
        return self._features
    # Export view: populate the cache once, then keep serving it.
    if self._features_scalar is None:
        self._build_scalar_features()
    return self._features_scalar
196+
@property
def feature_groups(self) -> List[model_pb2.FeatureGroupConfig]:
    """Return the item feature groups for the currently active view.

    Outside inference mode this is the stored ``_feature_groups`` list.
    In inference mode it is the cached scalar group list, which is built
    through ``_build_scalar_features`` the first time it is requested.
    """
    if not self._is_inference:
        return self._feature_groups
    # Export view: populate the cache once, then keep serving it.
    if self._feature_groups_scalar is None:
        self._build_scalar_features()
    return self._feature_groups_scalar
205+
def _build_scalar_features(self) -> None:
    """Fill the scalar-view caches from the stored grouped features.

    Each entry of ``self._features`` is passed through
    ``project_grouped_sequence_feature_to_scalar``; the resulting configs
    are handed to ``create_features``. The created features are stored in
    ``self._features_scalar`` and a single feature group named
    ``self._group_name`` listing them is stored in
    ``self._feature_groups_scalar``.
    """
    projected_configs = [
        project_grouped_sequence_feature_to_scalar(grouped)
        for grouped in self._features
    ]

    # The first stored feature supplies the fg settings reused for every
    # projected feature.
    template = self._features[0]
    fg_mode = template.fg_mode
    multival_sep = template._fg_encoded_multival_sep
    # Force the base data group as soon as any stored feature sits in it.
    in_base_group = any(
        grouped.data_group == BASE_DATA_GROUP for grouped in self._features
    )

    built = create_features(
        projected_configs,
        fg_mode=fg_mode,
        neg_fields=None,
        fg_encoded_multival_sep=multival_sep,
        force_base_data_group=in_base_group,
    )
    self._features_scalar = built

    # NOTE(review): the group is declared JAGGED_SEQUENCE although it lists
    # the scalar projections -- confirm this is the intended export type.
    scalar_group = model_pb2.FeatureGroupConfig(
        group_name=self._group_name,
        feature_names=[feature.name for feature in built],
        group_type=model_pb2.JAGGED_SEQUENCE,
    )
    self._feature_groups_scalar = [scalar_group]
229+
169230 def forward (self , grouped_features : Dict [str , torch .Tensor ]) -> torch .Tensor :
170231 """Forward the item tower.
171232
@@ -175,7 +236,9 @@ def forward(self, grouped_features: Dict[str, torch.Tensor]) -> torch.Tensor:
175236 Returns:
176237 item embeddings of shape (sum_candidates, D).
177238 """
178- cand_emb = grouped_features [f"{ self ._group_name } .sequence" ]
239+ # `.sequence` (jagged) at training, `.query` (scalar) at export.
240+ suffix = ".query" if self ._is_inference else ".sequence"
241+ cand_emb = grouped_features [self ._group_name + suffix ]
179242 item_emb = self .mlp (cand_emb )
180243 if self ._output_dim > 0 :
181244 item_emb = self .output (item_emb )
0 commit comments