1414import torch
1515from hypothesis import Verbosity , assume , given , settings
1616from hypothesis import strategies as st
17- from torchrec import JaggedTensor , KeyedJaggedTensor
17+ from torchrec import JaggedTensor , KeyedJaggedTensor , KeyedTensor
1818
1919from tzrec .datasets .utils import BASE_DATA_GROUP , CAND_POS_LENGTHS , Batch
2020from tzrec .features .feature import create_features
@@ -43,6 +43,10 @@ def _build_model(device: torch.device) -> HSTUMatch:
4343 dim / `embedding_name` so the two flattened features share one
4444 embedding table. `uih_seq` also carries the `historical_ts` raw
4545 sub-feature for the timestamp dense path.
46+
47+ Time encoding is on, with a scalar ``request_time`` raw feature exposed
48+ through a ``query_time`` DEEP group — the per-row time-bias anchor
49+ (mirrors the production config).
4650 """
4751 feature_cfgs = [
4852 feature_pb2 .FeatureConfig (
@@ -84,6 +88,11 @@ def _build_model(device: torch.device) -> HSTUMatch:
8488 )
8589 ),
8690 ]
91+ feature_cfgs .append (
92+ feature_pb2 .FeatureConfig (
93+ raw_feature = feature_pb2 .RawFeature (feature_name = "request_time" )
94+ )
95+ )
8796 features = create_features (feature_cfgs )
8897 feature_groups = [
8998 model_pb2 .FeatureGroupConfig (
@@ -102,6 +111,13 @@ def _build_model(device: torch.device) -> HSTUMatch:
102111 group_type = model_pb2 .FeatureGroupType .JAGGED_SEQUENCE ,
103112 ),
104113 ]
114+ feature_groups .append (
115+ model_pb2 .FeatureGroupConfig (
116+ group_name = "query_time" ,
117+ feature_names = ["request_time" ],
118+ group_type = model_pb2 .FeatureGroupType .DEEP ,
119+ )
120+ )
105121 model_config = model_pb2 .ModelConfig (
106122 feature_groups = feature_groups ,
107123 hstu_match = match_model_pb2 .HSTUMatch (
@@ -120,6 +136,8 @@ def _build_model(device: torch.device) -> HSTUMatch:
120136 attn_num_layers = 2 ,
121137 positional_encoder = module_pb2 .GRPositionalEncoder (
122138 num_position_buckets = 512 ,
139+ num_time_buckets = 512 ,
140+ use_time_encoding = True ,
123141 ),
124142 input_preprocessor = module_pb2 .GRInputPreprocessor (
125143 uih_preprocessor = module_pb2 .GRUIHPreprocessor (),
@@ -160,6 +178,9 @@ def _build_batch(device: torch.device) -> Batch:
160178 Candidates: row 0 = [pos_0]; row 1 (last) = [pos_1, simple_neg_0,
161179 simple_neg_1] -- the shared simple-neg pool sits in the last row's suffix.
162180 pos_lengths = [1, 1].
181+
182+ A per-row ``request_time`` dense scalar (strictly after each user's last
183+ UIH event at ts 3 / 7) is included as the time-bias anchor.
163184 """
164185 sparse_feature = KeyedJaggedTensor .from_lengths_sync (
165186 keys = ["uih_seq__video_id" , "cand_seq__video_id" ],
@@ -172,7 +193,14 @@ def _build_batch(device: torch.device) -> Batch:
172193 lengths = torch .tensor ([3 , 4 ]),
173194 ),
174195 }
196+ dense_features = {
197+ BASE_DATA_GROUP : KeyedTensor .from_tensor_list (
198+ keys = ["request_time" ],
199+ tensors = [torch .tensor ([[100.0 ], [100.0 ]])],
200+ )
201+ }
175202 return Batch (
203+ dense_features = dense_features ,
176204 sparse_features = {BASE_DATA_GROUP : sparse_feature },
177205 sequence_dense_features = sequence_dense_features ,
178206 jagged_labels = {
@@ -217,6 +245,9 @@ def test_hstu_match(self, graph_type, kernel, device_str) -> None:
217245
218246 device = torch .device (device_str )
219247 hstu = _build_model (device = device )
248+ # The query_time DEEP group is detected and threaded as the per-row
249+ # time-bias anchor (request-time anchoring, not the last UIH event).
250+ self .assertEqual (hstu .user_tower ._hstu_encoder ._query_time_key , "query_time" )
220251 hstu .set_kernel (kernel )
221252 batch = _build_batch (device = device )
222253
0 commit comments