forked from xai-org/x-algorithm
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathrunners.py
More file actions
807 lines (656 loc) · 25.9 KB
/
runners.py
File metadata and controls
807 lines (656 loc) · 25.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
# Copyright 2026 X.AI Corp.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, List, NamedTuple, Optional, Tuple
import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np
from grok import TrainingState
from recsys_retrieval_model import PhoenixRetrievalModelConfig
from recsys_retrieval_model import RetrievalOutput as ModelRetrievalOutput
from recsys_model import (
PhoenixModelConfig,
RecsysBatch,
RecsysEmbeddings,
RecsysModelOutput,
)
# Module-level logger named "rank"; BaseModelRunner.initialize logs through it.
rank_logger = logging.getLogger("rank")
def load_model_params(checkpoint_path: str, allow_pickle: bool = True) -> hk.Params:
    """Load model parameters from an exported checkpoint.

    Each archive key is split on its last ``/``: the prefix is the Haiku module
    path and the final component is the parameter name, e.g.
    ``"encoder/linear/w"`` -> ``params["encoder/linear"]["w"]``. A key without
    any ``/`` is stored under the empty module path ``""``.

    Args:
        checkpoint_path: Path to model_params.npz file
        allow_pickle: Forwarded to ``np.load``. Defaults to True to keep the
            historical behaviour. Pass False for checkpoints from untrusted
            sources, because unpickling object arrays can execute arbitrary code.

    Returns:
        Haiku params mapping ``{module_path: {param_name: jax.Array}}``, built
        with ``hk.data_structures.to_haiku_dict``.
    """
    params: dict = {}
    # NpzFile is a context manager; closing it releases the underlying zip file
    # handle once every array has been copied into a JAX array.
    with np.load(checkpoint_path, allow_pickle=allow_pickle) as data:
        for key in data.files:
            module_path, _, param_name = key.rpartition("/")
            params.setdefault(module_path, {})[param_name] = jnp.array(data[key])
    return hk.data_structures.to_haiku_dict(params)
def load_embedding_table(path: str) -> dict[str, np.ndarray]:
    """Load embedding tables from an exported checkpoint.

    Args:
        path: Path to embedding_tables.npz file (any path or file-like object
            accepted by ``np.load``).

    Returns:
        Dict mapping every array name stored in the archive (for example
        'user_embeddings', 'item_embeddings', 'author_embeddings') to its fully
        loaded ``np.ndarray``.
    """
    # ``np.load`` on an .npz returns a lazily-reading NpzFile. Materialise every
    # array while the archive is open, then let the context manager close it.
    with np.load(path) as archive:
        return {name: archive[name] for name in archive.files}
def create_dummy_batch_from_config(
    hash_config: Any,
    history_len: int,
    num_candidates: int,
    num_actions: int,
    batch_size: int = 1,
) -> RecsysBatch:
    """Build an all-zeros RecsysBatch used to trace shapes during initialization.

    Args:
        hash_config: Object exposing ``num_user_hashes``, ``num_item_hashes``
            and ``num_author_hashes``.
        history_len: Number of history positions per example.
        num_candidates: Number of candidate positions per example.
        num_actions: Width of the per-history-item action vector.
        batch_size: Leading batch dimension.

    Returns:
        RecsysBatch whose hash and product-surface fields are int32 zeros and
        whose ``history_actions`` field is float32 zeros.
    """

    def ints(*shape: int) -> np.ndarray:
        return np.zeros(shape, dtype=np.int32)

    n_user = hash_config.num_user_hashes
    n_item = hash_config.num_item_hashes
    n_author = hash_config.num_author_hashes
    hist = (batch_size, history_len)
    cand = (batch_size, num_candidates)

    return RecsysBatch(
        user_hashes=ints(batch_size, n_user),
        history_post_hashes=ints(*hist, n_item),
        history_author_hashes=ints(*hist, n_author),
        history_actions=np.zeros((*hist, num_actions), dtype=np.float32),
        history_product_surface=ints(*hist),
        candidate_post_hashes=ints(*cand, n_item),
        candidate_author_hashes=ints(*cand, n_author),
        candidate_product_surface=ints(*cand),
    )
def create_dummy_embeddings_from_config(
    hash_config: Any,
    emb_size: int,
    history_len: int,
    num_candidates: int,
    batch_size: int = 1,
) -> RecsysEmbeddings:
    """Build an all-zeros RecsysEmbeddings used to trace shapes during initialization.

    Every field is a float32 array whose last axis has length ``emb_size``.

    Args:
        hash_config: Object exposing ``num_user_hashes``, ``num_item_hashes``
            and ``num_author_hashes``.
        emb_size: Embedding dimension.
        history_len: Number of history positions per example.
        num_candidates: Number of candidate positions per example.
        batch_size: Leading batch dimension.

    Returns:
        RecsysEmbeddings filled with float32 zeros.
    """

    def zeros(*leading: int) -> np.ndarray:
        return np.zeros((*leading, emb_size), dtype=np.float32)

    n_user = hash_config.num_user_hashes
    n_item = hash_config.num_item_hashes
    n_author = hash_config.num_author_hashes

    return RecsysEmbeddings(
        user_embeddings=zeros(batch_size, n_user),
        history_post_embeddings=zeros(batch_size, history_len, n_item),
        candidate_post_embeddings=zeros(batch_size, num_candidates, n_item),
        history_author_embeddings=zeros(batch_size, history_len, n_author),
        candidate_author_embeddings=zeros(batch_size, num_candidates, n_author),
    )
@dataclass
class BaseModelRunner(ABC):
    """Shared scaffolding for model runners.

    Subclasses supply the model config (``model``) and the forward function
    (``make_forward_fn``); ``initialize`` wires them together and sets:

    * ``batch_size``: ``bs_per_device`` times the number of local JAX devices,
      truncated to an int and floored at 1.
    * ``forward``: whatever ``make_forward_fn`` returns.
    """

    # Per-device batch size; may be fractional (see ``initialize``).
    bs_per_device: float = 2.0
    # Seed available to subclasses for building a PRNG key.
    rng_seed: int = 42

    @property
    @abstractmethod
    def model(self) -> Any:
        """Return the model config."""

    @property
    def _model_name(self) -> str:
        """Human-readable model name used in log messages."""
        return "model"

    @abstractmethod
    def make_forward_fn(self):
        """Build and return the forward function (subclass responsibility)."""

    def initialize(self):
        """Prepare the model config, derive the batch size and build ``forward``."""
        self.model.initialize()
        # fprop_dtype is set to bfloat16 (presumably the forward-pass compute dtype).
        self.model.fprop_dtype = jnp.bfloat16

        device_count = len(jax.local_devices())
        scaled = int(self.bs_per_device * device_count)
        self.batch_size = scaled if scaled >= 1 else 1

        rank_logger.info(f"Initializing {self._model_name}...")
        self.forward = self.make_forward_fn()
@dataclass
class BaseInferenceRunner(ABC):
    """Shared scaffolding for inference runners: builds zero-filled dummy inputs.

    The dummy inputs are sized from the wrapped runner's model config
    (``hash_config``, ``history_seq_len``, ``candidate_seq_len``, ``emb_size``).
    """

    # Identifier for this runner instance.
    name: str

    @property
    @abstractmethod
    def runner(self) -> BaseModelRunner:
        """Return the wrapped model runner."""

    def _get_num_actions(self) -> int:
        """Return the config's ``num_actions`` if it defines one, otherwise 19."""
        # The fallback 19 equals the length of the module-level ACTIONS list.
        return getattr(self.runner.model, "num_actions", 19)

    def create_dummy_batch(self, batch_size: int = 1) -> RecsysBatch:
        """Return a zero-filled RecsysBatch matching the model config."""
        cfg = self.runner.model
        hashes = cfg.hash_config
        hist_len = cfg.history_seq_len
        cand_len = cfg.candidate_seq_len
        return create_dummy_batch_from_config(
            hash_config=hashes,
            history_len=hist_len,
            num_candidates=cand_len,
            num_actions=self._get_num_actions(),
            batch_size=batch_size,
        )

    def create_dummy_embeddings(self, batch_size: int = 1) -> RecsysEmbeddings:
        """Return zero-filled RecsysEmbeddings matching the model config."""
        cfg = self.runner.model
        hashes = cfg.hash_config
        width = cfg.emb_size
        hist_len = cfg.history_seq_len
        cand_len = cfg.candidate_seq_len
        return create_dummy_embeddings_from_config(
            hash_config=hashes,
            emb_size=width,
            history_len=hist_len,
            num_candidates=cand_len,
            batch_size=batch_size,
        )

    @abstractmethod
    def initialize(self):
        """Set up the inference runner (subclass responsibility)."""
# Names of the engagement heads of the ranking model. Index i corresponds to
# channel i of the ranking probabilities (probs[:, :, i]) in
# RecsysInferenceRunner, exposed there as the ``p_<name>`` field.
ACTIONS: List[str] = [
    "favorite_score",
    "reply_score",
    "repost_score",
    "photo_expand_score",
    "click_score",
    "profile_click_score",
    "vqv_score",
    "share_score",
    "share_via_dm_score",
    "share_via_copy_link_score",
    "dwell_score",
    "quote_score",
    "quoted_click_score",
    "follow_author_score",
    "not_interested_score",
    "block_author_score",
    "mute_author_score",
    "report_score",
    "dwell_time",
]

# Slot names for continuous action features. Slot 1 ("dwell_time") is the one
# create_example_batch fills when continuous actions are requested.
CONTINUOUS_ACTIONS: List[str] = [
    "reserved",
    "dwell_time",
    "video_watch_time",
    "scroll_depth",
    "reserved_3",
    "reserved_4",
    "reserved_5",
    "reserved_6",
]

# Positions in ACTIONS of the negative-feedback heads. The indices are derived
# from the names so they follow any reordering of ACTIONS; currently
# [14, 15, 16, 17].
NEGATIVE_FEEDBACK_INDICES: List[int] = [
    ACTIONS.index(action_name)
    for action_name in (
        "not_interested_score",
        "block_author_score",
        "mute_author_score",
        "report_score",
    )
]
class RankingOutput(NamedTuple):
    """Output from ranking candidates.

    Contains the full probability array (sigmoid of the model logits) and
    individual probability fields for each engagement type.

    As built by ``RecsysInferenceRunner``: each ``p_*`` field is one channel of
    ``scores`` (channel i corresponds to ``ACTIONS[i]``), and
    ``ranked_indices`` orders candidates by descending ``p_favorite_score``.
    """

    # Sigmoid probabilities for every head, indexed as scores[:, :, head];
    # presumably [batch, candidates, num_heads] - TODO confirm layout.
    scores: jax.Array
    # argsort of -p_favorite_score along the last axis (highest score first).
    ranked_indices: jax.Array
    # Per-head probabilities: scores[:, :, i] for the i-th name in ACTIONS.
    p_favorite_score: jax.Array
    p_reply_score: jax.Array
    p_repost_score: jax.Array
    p_photo_expand_score: jax.Array
    p_click_score: jax.Array
    p_profile_click_score: jax.Array
    p_vqv_score: jax.Array
    p_share_score: jax.Array
    p_share_via_dm_score: jax.Array
    p_share_via_copy_link_score: jax.Array
    p_dwell_score: jax.Array
    p_quote_score: jax.Array
    p_quoted_click_score: jax.Array
    p_follow_author_score: jax.Array
    # Negative-feedback heads (indices 14-17 in ACTIONS).
    p_not_interested_score: jax.Array
    p_block_author_score: jax.Array
    p_mute_author_score: jax.Array
    p_report_score: jax.Array
    # NOTE(review): the sigmoid is applied to this head too, although its name
    # suggests a continuous quantity - confirm this is intended.
    p_dwell_time: jax.Array
    # Passed through unchanged from the model output's ``continuous_preds``.
    continuous_preds: Optional[jax.Array] = None
@dataclass
class ModelRunner(BaseModelRunner):
    """Runner for the recommendation ranking model.

    Wraps a ``PhoenixModelConfig``. ``initialize()`` (inherited) must run before
    ``init`` or before ``load_or_init`` without a checkpoint, because it builds
    ``self.forward``.
    """

    # Stored config; exposed through the ``model`` property.
    _model: PhoenixModelConfig = None  # type: ignore

    def __init__(self, model: PhoenixModelConfig, bs_per_device: float = 2.0, rng_seed: int = 42):
        self._model = model
        self.bs_per_device = bs_per_device
        self.rng_seed = rng_seed

    @property
    def model(self) -> PhoenixModelConfig:
        return self._model

    @property
    def _model_name(self) -> str:
        return "ranking model"

    def make_forward_fn(self):  # type: ignore
        """Return ``hk.transform`` of the ranking model's forward pass."""

        def forward(batch: RecsysBatch, recsys_embeddings: RecsysEmbeddings):
            return self.model.make()(batch, recsys_embeddings)

        return hk.transform(forward)

    def _require_forward(self):
        """Return ``self.forward``, raising if ``initialize()`` has not been called.

        An explicit exception is used instead of ``assert`` so the check still
        runs under ``python -O``, and a missing attribute yields a clear message
        instead of an AttributeError.
        """
        forward = getattr(self, "forward", None)
        if forward is None:
            raise RuntimeError(
                f"{self._model_name} runner is not initialized; call initialize() first"
            )
        return forward

    def init(
        self, rng: jax.Array, data: RecsysBatch, embeddings: RecsysEmbeddings
    ) -> TrainingState:
        """Create fresh parameters by tracing the forward pass on sample inputs.

        Args:
            rng: PRNG key. It is split and the second half seeds initialization.
            data: Sample batch used only for shapes/dtypes during tracing.
            embeddings: Sample embeddings matching ``data``.

        Returns:
            TrainingState holding the initialized params.

        Raises:
            RuntimeError: If ``initialize()`` has not been called.
        """
        _, init_rng = jax.random.split(rng)
        params = self._require_forward().init(init_rng, data, embeddings)
        return TrainingState(params=params)

    def load_or_init(
        self,
        init_data: RecsysBatch,
        init_embeddings: RecsysEmbeddings,
        checkpoint_path: Optional[str] = None,
    ) -> TrainingState:
        """Load params from ``checkpoint_path`` if given, otherwise initialize them.

        Random initialization uses ``jax.random.PRNGKey(self.rng_seed)``.
        """
        if checkpoint_path is not None:
            params = load_model_params(checkpoint_path)
            return TrainingState(params=params)
        rng = jax.random.PRNGKey(self.rng_seed)
        return self.init(rng, init_data, init_embeddings)
@dataclass
class RecsysInferenceRunner(BaseInferenceRunner):
    """Inference runner for the recommendation ranking model.

    Call ``initialize()`` once, then ``rank(batch, embeddings)`` to score and
    order candidates with the parameters stored in ``self.params``.
    """

    # Underlying model runner; exposed through the ``runner`` property.
    _runner: ModelRunner

    def __init__(self, runner: ModelRunner, name: str):
        self.name = name
        self._runner = runner

    @property
    def runner(self) -> ModelRunner:
        return self._runner

    def initialize(self, checkpoint_path: Optional[str] = None):
        """Initialize the inference runner.

        Args:
            checkpoint_path: Optional path to a model_params.npz checkpoint.
                When given, its parameters are loaded. Otherwise parameters are
                randomly initialized from the runner's ``rng_seed`` (the
                previous, and still default, behaviour).
        """
        runner = self.runner
        dummy_batch = self.create_dummy_batch(batch_size=1)
        dummy_embeddings = self.create_dummy_embeddings(batch_size=1)
        runner.initialize()
        state = runner.load_or_init(
            dummy_batch, dummy_embeddings, checkpoint_path=checkpoint_path
        )
        self.params = state.params

        # Zero-argument lru_cache: the model is built on the first call and the
        # same instance is returned on every later call.
        @functools.lru_cache
        def model():
            return runner.model.make()

        def hk_forward(
            batch: RecsysBatch, recsys_embeddings: RecsysEmbeddings
        ) -> RecsysModelOutput:
            return model()(batch, recsys_embeddings)

        def hk_rank_candidates(
            batch: RecsysBatch, recsys_embeddings: RecsysEmbeddings
        ) -> RankingOutput:
            """Rank candidates by their predicted engagement scores."""
            output = hk_forward(batch, recsys_embeddings)
            probs = jax.nn.sigmoid(output.logits)
            # Order candidates by descending probability of head 0 (favorite).
            ranked_indices = jnp.argsort(-probs[:, :, 0], axis=-1)
            # ACTIONS lists the heads in channel order, so channel i becomes the
            # ``p_<ACTIONS[i]>`` field. This keeps RankingOutput in sync with
            # ACTIONS instead of relying on 19 hand-written indices.
            per_action = {
                f"p_{action}": probs[:, :, i] for i, action in enumerate(ACTIONS)
            }
            return RankingOutput(
                scores=probs,
                ranked_indices=ranked_indices,
                continuous_preds=output.continuous_preds,
                **per_action,
            )

        rank_ = hk.without_apply_rng(hk.transform(hk_rank_candidates))
        self.rank_candidates = rank_.apply

    def rank(self, batch: RecsysBatch, recsys_embeddings: RecsysEmbeddings) -> RankingOutput:
        """Rank candidates for the given batch.

        Args:
            batch: RecsysBatch containing hashes, actions, product surfaces
            recsys_embeddings: RecsysEmbeddings containing pre-looked-up embeddings

        Returns:
            RankingOutput with scores and ranked indices
        """
        return self.rank_candidates(self.params, batch, recsys_embeddings)
def create_example_batch(
    batch_size: int,
    emb_size: int,
    history_len: int,
    num_candidates: int,
    num_actions: int,
    num_user_hashes: int = 2,
    num_item_hashes: int = 2,
    num_author_hashes: int = 2,
    product_surface_vocab_size: int = 16,
    num_user_embeddings: int = 1_000_000,
    num_post_embeddings: int = 1_000_000,
    num_author_embeddings: int = 1_000_000,
    include_continuous_actions: bool = False,
    include_timestamps: bool = False,
    num_continuous_actions: int = 8,
    seed: int = 42,
) -> Tuple[RecsysBatch, RecsysEmbeddings]:
    """Create an example batch with random data for testing.

    This simulates a recommendation scenario where:
    - We have a user with some embedding
    - The user has interacted with some posts in their history
    - We want to rank a set of candidate posts

    Note on embedding table sizes:
    The num_*_embeddings parameters define the size of the embedding tables for each
    entity type. Hash values are generated in the range [1, num_*_embeddings) to ensure
    they can be used as valid indices into the corresponding embedding tables.
    Hash value 0 is reserved for padding/invalid entries.

    History padding: each example gets one random valid history length in
    [history_len // 2, history_len]. Positions at or beyond it have hash 0 in
    BOTH ``history_post_hashes`` and ``history_author_hashes``. Actions, product
    surfaces and embeddings remain random at padded positions.

    Args:
        batch_size: Number of examples.
        emb_size: Width of every returned embedding.
        history_len: History positions per example.
        num_candidates: Candidate positions per example.
        num_actions: Width of ``history_actions`` (binary, each entry is 1 with
            probability 0.3).
        num_user_hashes: Hashes per user.
        num_item_hashes: Hashes per post.
        num_author_hashes: Hashes per author.
        product_surface_vocab_size: Product surfaces are drawn from
            [0, product_surface_vocab_size).
        num_user_embeddings: Exclusive upper bound of user hash values.
        num_post_embeddings: Exclusive upper bound of post hash values.
        num_author_embeddings: Exclusive upper bound of author hash values.
        include_continuous_actions: If True, add ``history_continuous_actions``
            of shape (batch_size, history_len, num_continuous_actions). It is
            zero except channel 1, which holds exponential(scale=10) samples.
        include_timestamps: If True, add int32 ``candidate_impr_ts`` (all
            1700000000) and ``candidate_post_creation_ts`` (impression time
            minus an age between 60 s and 72 h).
        num_continuous_actions: Continuous-action channel count. Must be at
            least 2 when ``include_continuous_actions`` is True.
        seed: Seed for ``np.random.default_rng`` (defaults to 42).

    Returns:
        Tuple of (RecsysBatch, RecsysEmbeddings)

    Raises:
        ValueError: If ``include_continuous_actions`` is True and
            ``num_continuous_actions`` < 2.
    """
    if include_continuous_actions and num_continuous_actions < 2:
        raise ValueError(
            "num_continuous_actions must be >= 2 to hold the dwell-time channel "
            f"(index 1), got {num_continuous_actions}"
        )

    rng = np.random.default_rng(seed)

    user_hashes = rng.integers(1, num_user_embeddings, size=(batch_size, num_user_hashes)).astype(
        np.int32
    )

    history_post_hashes = rng.integers(
        1, num_post_embeddings, size=(batch_size, history_len, num_item_hashes)
    ).astype(np.int32)
    # Draw ONE valid history length per example and apply it to both the post
    # and author hashes, so a padded position is padded in both. Previously two
    # independent lengths were drawn, making the two padding masks disagree.
    history_valid_lens = [
        rng.integers(history_len // 2, history_len + 1) for _ in range(batch_size)
    ]
    history_author_hashes = rng.integers(
        1, num_author_embeddings, size=(batch_size, history_len, num_author_hashes)
    ).astype(np.int32)
    for b, valid_len in enumerate(history_valid_lens):
        history_post_hashes[b, valid_len:, :] = 0
        history_author_hashes[b, valid_len:, :] = 0

    history_actions = (rng.random(size=(batch_size, history_len, num_actions)) > 0.7).astype(
        np.float32
    )
    history_product_surface = rng.integers(
        0, product_surface_vocab_size, size=(batch_size, history_len)
    ).astype(np.int32)

    candidate_post_hashes = rng.integers(
        1, num_post_embeddings, size=(batch_size, num_candidates, num_item_hashes)
    ).astype(np.int32)
    candidate_author_hashes = rng.integers(
        1, num_author_embeddings, size=(batch_size, num_candidates, num_author_hashes)
    ).astype(np.int32)
    candidate_product_surface = rng.integers(
        0, product_surface_vocab_size, size=(batch_size, num_candidates)
    ).astype(np.int32)

    history_continuous_actions = None
    if include_continuous_actions:
        history_continuous_actions = np.zeros(
            (batch_size, history_len, num_continuous_actions), dtype=np.float32
        )
        # Channel 1 corresponds to "dwell_time" in CONTINUOUS_ACTIONS.
        history_continuous_actions[:, :, 1] = rng.exponential(
            scale=10.0, size=(batch_size, history_len)
        ).astype(np.float32)

    candidate_impr_ts = None
    candidate_post_creation_ts = None
    if include_timestamps:
        base_ts = 1700000000
        candidate_impr_ts = np.full((batch_size, num_candidates), base_ts, dtype=np.int32)
        age_seconds = rng.integers(60, 72 * 3600, size=(batch_size, num_candidates))
        candidate_post_creation_ts = (candidate_impr_ts - age_seconds).astype(np.int32)

    batch = RecsysBatch(
        user_hashes=user_hashes,
        history_post_hashes=history_post_hashes,
        history_author_hashes=history_author_hashes,
        history_actions=history_actions,
        history_product_surface=history_product_surface,
        candidate_post_hashes=candidate_post_hashes,
        candidate_author_hashes=candidate_author_hashes,
        candidate_product_surface=candidate_product_surface,
        history_continuous_actions=history_continuous_actions,
        candidate_impr_ts=candidate_impr_ts,
        candidate_post_creation_ts=candidate_post_creation_ts,
    )

    embeddings = RecsysEmbeddings(
        user_embeddings=rng.normal(size=(batch_size, num_user_hashes, emb_size)).astype(np.float32),
        history_post_embeddings=rng.normal(
            size=(batch_size, history_len, num_item_hashes, emb_size)
        ).astype(np.float32),
        candidate_post_embeddings=rng.normal(
            size=(batch_size, num_candidates, num_item_hashes, emb_size)
        ).astype(np.float32),
        history_author_embeddings=rng.normal(
            size=(batch_size, history_len, num_author_hashes, emb_size)
        ).astype(np.float32),
        candidate_author_embeddings=rng.normal(
            size=(batch_size, num_candidates, num_author_hashes, emb_size)
        ).astype(np.float32),
    )

    return batch, embeddings
class RetrievalOutput(NamedTuple):
    """Output from retrieval inference.

    Contains user representations and retrieved candidates.

    NOTE(review): ``RecsysRetrievalInferenceRunner.retrieve`` returns whatever
    the retrieval model's ``__call__`` returns (annotated elsewhere as
    ``recsys_retrieval_model.RetrievalOutput``), not an instance of this local
    class. Confirm whether this type is still needed.
    """

    # User representations; presumably (batch, dim) - TODO confirm.
    user_representation: jax.Array
    # Corpus row indices of the top-k hits per user; presumably (batch, k).
    top_k_indices: jax.Array
    # Scores aligned with ``top_k_indices``; presumably (batch, k).
    top_k_scores: jax.Array
@dataclass
class RetrievalModelRunner(BaseModelRunner):
    """Runner for the Phoenix retrieval model.

    Wraps a ``PhoenixRetrievalModelConfig``. ``initialize()`` (inherited) must
    run before ``init`` or before ``load_or_init`` without a checkpoint, because
    it builds ``self.forward``.
    """

    # Stored config; exposed through the ``model`` property.
    _model: PhoenixRetrievalModelConfig = None  # type: ignore

    def __init__(
        self,
        model: PhoenixRetrievalModelConfig,
        bs_per_device: float = 2.0,
        rng_seed: int = 42,
    ):
        self._model = model
        self.bs_per_device = bs_per_device
        self.rng_seed = rng_seed

    @property
    def model(self) -> PhoenixRetrievalModelConfig:
        return self._model

    @property
    def _model_name(self) -> str:
        return "retrieval model"

    def make_forward_fn(self):  # type: ignore
        """Return ``hk.transform`` of the retrieval forward pass."""

        def forward(
            batch: RecsysBatch,
            recsys_embeddings: RecsysEmbeddings,
            corpus_embeddings: jax.Array,
            top_k: int,
        ) -> ModelRetrievalOutput:
            model = self.model.make()
            out = model(batch, recsys_embeddings, corpus_embeddings, top_k)
            # The candidate tower is also run (result discarded), presumably so
            # any parameters it owns are created during ``init`` even though
            # the call above may not touch it - verify against the model.
            model.build_candidate_representation(batch, recsys_embeddings)
            return out

        return hk.transform(forward)

    def _require_forward(self):
        """Return ``self.forward``, raising if ``initialize()`` has not been called.

        An explicit exception is used instead of ``assert`` so the check still
        runs under ``python -O``.
        """
        forward = getattr(self, "forward", None)
        if forward is None:
            raise RuntimeError(
                f"{self._model_name} runner is not initialized; call initialize() first"
            )
        return forward

    def init(
        self,
        rng: jax.Array,
        data: RecsysBatch,
        embeddings: RecsysEmbeddings,
        corpus_embeddings: jax.Array,
        top_k: int,
    ) -> TrainingState:
        """Create fresh parameters by tracing the forward pass on sample inputs.

        Args:
            rng: PRNG key. It is split and the second half seeds initialization.
            data: Sample batch used for tracing.
            embeddings: Sample embeddings matching ``data``.
            corpus_embeddings: Sample corpus used for tracing.
            top_k: Number of candidates to retrieve while tracing.

        Returns:
            TrainingState holding the initialized params.

        Raises:
            RuntimeError: If ``initialize()`` has not been called.
        """
        _, init_rng = jax.random.split(rng)
        params = self._require_forward().init(
            init_rng, data, embeddings, corpus_embeddings, top_k
        )
        return TrainingState(params=params)

    def load_or_init(
        self,
        init_data: RecsysBatch,
        init_embeddings: RecsysEmbeddings,
        corpus_embeddings: jax.Array,
        top_k: int,
        checkpoint_path: Optional[str] = None,
    ) -> TrainingState:
        """Load params from ``checkpoint_path`` if given, otherwise initialize them.

        Mirrors ``ModelRunner.load_or_init``. Random initialization uses
        ``jax.random.PRNGKey(self.rng_seed)``. ``checkpoint_path`` is optional
        and defaults to None, so existing callers are unaffected.
        """
        if checkpoint_path is not None:
            return TrainingState(params=load_model_params(checkpoint_path))
        rng = jax.random.PRNGKey(self.rng_seed)
        return self.init(rng, init_data, init_embeddings, corpus_embeddings, top_k)
@dataclass
class RecsysRetrievalInferenceRunner(BaseInferenceRunner):
    """Inference runner for the Phoenix retrieval model.

    This runner provides methods for:
    1. Encoding users to get user representations
    2. Encoding candidates to get candidate embeddings
    3. Retrieving top-k candidates from a corpus

    Call ``initialize()`` first. For retrieval, either register a default corpus
    with ``set_corpus`` or pass ``corpus_embeddings`` to ``retrieve``.
    """

    _runner: RetrievalModelRunner = None  # type: ignore
    # Default corpus used by ``retrieve`` when none is passed explicitly.
    corpus_embeddings: jax.Array | None = None
    # IDs aligned with the rows of ``corpus_embeddings``; stored but not read here.
    corpus_post_ids: jax.Array | None = None

    def __init__(self, runner: RetrievalModelRunner, name: str):
        self.name = name
        self._runner = runner
        self.corpus_embeddings = None
        self.corpus_post_ids = None

    @property
    def runner(self) -> RetrievalModelRunner:
        return self._runner

    def initialize(self):
        """Initialize the retrieval inference runner."""
        runner = self.runner
        dummy_batch = self.create_dummy_batch(batch_size=1)
        dummy_embeddings = self.create_dummy_embeddings(batch_size=1)
        # Small zero corpus and top_k, used only to trace parameter creation.
        dummy_corpus = jnp.zeros((10, runner.model.emb_size), dtype=jnp.float32)
        dummy_top_k = 5

        runner.initialize()
        state = runner.load_or_init(dummy_batch, dummy_embeddings, dummy_corpus, dummy_top_k)
        self.params = state.params

        # Zero-argument lru_cache: the model is built on the first call and the
        # same instance is returned on every later call.
        @functools.lru_cache
        def model():
            return runner.model.make()

        def hk_encode_user(batch: RecsysBatch, recsys_embeddings: RecsysEmbeddings) -> jax.Array:
            """Encode user to get user representation."""
            m = model()
            user_rep, _ = m.build_user_representation(batch, recsys_embeddings)
            return user_rep

        def hk_encode_candidates(
            batch: RecsysBatch, recsys_embeddings: RecsysEmbeddings
        ) -> jax.Array:
            """Encode candidates to get candidate representations."""
            m = model()
            cand_rep, _ = m.build_candidate_representation(batch, recsys_embeddings)
            return cand_rep

        def hk_retrieve(
            batch: RecsysBatch,
            recsys_embeddings: RecsysEmbeddings,
            corpus_embeddings: jax.Array,
            top_k: int,
        ) -> ModelRetrievalOutput:
            """Retrieve top-k candidates from corpus."""
            m = model()
            return m(batch, recsys_embeddings, corpus_embeddings, top_k)

        encode_user_ = hk.without_apply_rng(hk.transform(hk_encode_user))
        encode_candidates_ = hk.without_apply_rng(hk.transform(hk_encode_candidates))
        retrieve_ = hk.without_apply_rng(hk.transform(hk_retrieve))

        self.encode_user_fn = encode_user_.apply
        self.encode_candidates_fn = encode_candidates_.apply
        self.retrieve_fn = retrieve_.apply

    def encode_user(self, batch: RecsysBatch, recsys_embeddings: RecsysEmbeddings) -> jax.Array:
        """Encode users to get user representations.

        Args:
            batch: RecsysBatch containing user and history information
            recsys_embeddings: RecsysEmbeddings containing pre-looked-up embeddings

        Returns:
            User representations [B, D]
        """
        return self.encode_user_fn(self.params, batch, recsys_embeddings)

    def encode_candidates(
        self, batch: RecsysBatch, recsys_embeddings: RecsysEmbeddings
    ) -> jax.Array:
        """Encode candidates to get candidate representations.

        Args:
            batch: RecsysBatch containing candidate information
            recsys_embeddings: RecsysEmbeddings containing pre-looked-up embeddings

        Returns:
            Candidate representations [B, C, D]
        """
        return self.encode_candidates_fn(self.params, batch, recsys_embeddings)

    def set_corpus(
        self,
        corpus_embeddings: jax.Array,
        corpus_post_ids: Optional[jax.Array] = None,
    ):
        """Set the default corpus used by ``retrieve``.

        Args:
            corpus_embeddings: Pre-computed candidate embeddings [N, D]
            corpus_post_ids: Optional post IDs corresponding to embeddings [N]
                (defaults to None)
        """
        self.corpus_embeddings = corpus_embeddings
        self.corpus_post_ids = corpus_post_ids

    def retrieve(
        self,
        batch: RecsysBatch,
        recsys_embeddings: RecsysEmbeddings,
        top_k: int = 100,
        corpus_embeddings: Optional[jax.Array] = None,
    ) -> ModelRetrievalOutput:
        """Retrieve top-k candidates for users.

        Args:
            batch: RecsysBatch containing user and history information
            recsys_embeddings: RecsysEmbeddings containing pre-looked-up embeddings
            top_k: Number of candidates to retrieve per user
            corpus_embeddings: Optional corpus embeddings (uses set_corpus if not provided)

        Returns:
            The retrieval model's output for this batch.

        Raises:
            ValueError: If no corpus was passed and none was set via ``set_corpus``.
        """
        if corpus_embeddings is None:
            corpus_embeddings = self.corpus_embeddings
        # Fail early with a clear message instead of passing None into the model.
        if corpus_embeddings is None:
            raise ValueError(
                "No corpus embeddings available: pass corpus_embeddings or call set_corpus() first"
            )
        return self.retrieve_fn(self.params, batch, recsys_embeddings, corpus_embeddings, top_k)
def create_example_corpus(
    corpus_size: int,
    emb_size: int,
    seed: int = 123,
) -> Tuple[jax.Array, jax.Array]:
    """Generate a random, L2-normalised corpus for exercising retrieval.

    Args:
        corpus_size: Number of corpus rows N.
        emb_size: Embedding width D.
        seed: Seed for ``np.random.default_rng``; equal seeds yield equal corpora.

    Returns:
        ``(corpus_embeddings, corpus_post_ids)``:
        - ``corpus_embeddings``: float32 rows of shape [N, D], each divided by
          its L2 norm (floored at 1e-12).
        - ``corpus_post_ids``: the IDs ``0..N-1``, built as int64. JAX may store
          them as int32 unless 64-bit mode is enabled.
        Both are returned as JAX arrays.
    """
    generator = np.random.default_rng(seed)
    raw = generator.normal(size=(corpus_size, emb_size)).astype(np.float32)

    row_norms = np.linalg.norm(raw, axis=-1, keepdims=True)
    safe_norms = np.maximum(row_norms, 1e-12)  # guard against division by zero
    unit_rows = raw / safe_norms

    post_ids = np.arange(corpus_size, dtype=np.int64)
    return jnp.array(unit_rows), jnp.array(post_ids)