|
| 1 | +# Copyright 2024 RecML authors <recommendations-ml@google.com>. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | +import jax |
| 15 | +import jax.numpy as jnp |
| 16 | +import numpy as np |
| 17 | +import numpy.testing as npt |
| 19 | +from absl.testing import absltest |
| 20 | +from recml.examples.DLRM_HSTU.action_encoder import ActionEncoder |
| 21 | + |
| 22 | + |
class ActionEncoderJaxTest(absltest.TestCase):
  """Unit tests for the Flax `ActionEncoder` module."""

  def test_forward_and_backward(self) -> None:
    """Tests the ActionEncoder's forward pass logic and differentiability.

    Builds a small batch of action bitmasks (explicit actions plus actions
    implied by watchtime thresholds), runs the encoder, and checks that:
      * padded positions produce all-zero embeddings,
      * target positions receive the full target-action embedding table,
      * history positions receive per-action-type embeddings only for the
        action types enabled in that item's bitmask,
      * gradients flow into both embedding tables.
    """
    batch_size = 2
    max_seq_len = 6
    action_embedding_dim = 32
    # Bit weights for the explicitly logged action types (bitmask encoding).
    action_weights = [1, 2, 4, 8, 16]
    # (threshold, bit_weight): a watchtime strictly greater than `threshold`
    # enables the corresponding implicit action type.
    watchtime_to_action_thresholds_and_weights = [
        (30, 32),
        (60, 64),
        (100, 128),
    ]
    num_action_types = len(action_weights) + len(
        watchtime_to_action_thresholds_and_weights
    )
    output_dim = action_embedding_dim * num_action_types
    combined_action_weights = action_weights + [
        w for _, w in watchtime_to_action_thresholds_and_weights
    ]

    # Per-item lists of enabled action-type indices, flattened over history
    # items only: items 0-3 belong to sequence 0, items 4-5 to sequence 1.
    enabled_actions = [
        [0],  # Seq 1, Item 1
        [0, 1],  # Seq 1, Item 2
        [1, 3, 4],  # Seq 1, Item 3
        [1, 2, 3, 4],  # Seq 1, Item 4
        [1, 2],  # Seq 2, Item 1
        [2],  # Seq 2, Item 2
    ]
    watchtimes_flat = [40, 20, 110, 31, 26, 55]

    # Enable the implicit action types whose watchtime threshold is exceeded.
    for i, wt in enumerate(watchtimes_flat):
      for j, (threshold, _) in enumerate(
          watchtime_to_action_thresholds_and_weights
      ):
        if wt > threshold:
          enabled_actions[i].append(j + len(action_weights))

    # Encode each item's enabled action types as a single integer bitmask.
    actions_flat = [
        sum(combined_action_weights[t] for t in x) for x in enabled_actions
    ]

    padded_actions = np.zeros((batch_size, max_seq_len), dtype=np.int64)
    padded_watchtimes = np.zeros((batch_size, max_seq_len), dtype=np.int64)

    # Sequence 0 has 4 history items; sequence 1 has 2.
    padded_actions[0, :4] = actions_flat[0:4]
    padded_actions[1, :2] = actions_flat[4:6]
    padded_watchtimes[0, :4] = watchtimes_flat[0:4]
    padded_watchtimes[1, :2] = watchtimes_flat[4:6]

    # Target positions follow the history items within each valid prefix.
    is_target_mask = np.zeros((batch_size, max_seq_len), dtype=bool)
    is_target_mask[0, 4:6] = True
    is_target_mask[1, 2] = True

    # Valid (non-padding) positions: history items plus target positions.
    padding_mask = np.zeros((batch_size, max_seq_len), dtype=bool)
    padding_mask[0, :6] = True
    padding_mask[1, :3] = True

    seq_payloads = {
        "watchtimes": jnp.array(padded_watchtimes),
        "actions": jnp.array(padded_actions),
    }

    encoder = ActionEncoder(
        watchtime_feature_name="watchtimes",
        action_feature_name="actions",
        action_weights=action_weights,
        watchtime_to_action_thresholds_and_weights=(
            watchtime_to_action_thresholds_and_weights
        ),
        action_embedding_dim=action_embedding_dim,
    )

    key = jax.random.PRNGKey(0)
    variables = encoder.init(key, seq_payloads, is_target_mask)
    params = variables["params"]

    action_embeddings = encoder.apply(
        variables, seq_payloads, is_target_mask
    )

    self.assertEqual(
        action_embeddings.shape, (batch_size, max_seq_len, output_dim)
    )

    action_table = params["action_embedding_table"]
    target_table_flat = params["target_action_embedding_table"]
    # The target table is stored flat; view it as one row per action type to
    # compare against per-type slices of the output.
    target_table = target_table_flat.reshape(num_action_types, -1)

    # Walk the output position by position, consuming flat history items in
    # order; targets and padding do not advance the history index.
    history_item_idx = 0
    for b in range(batch_size):
      for s in range(max_seq_len):
        if not padding_mask[b, s]:
          # Padded positions must be exactly zeroed out.
          npt.assert_allclose(action_embeddings[b, s], 0, atol=1e-6)
          continue

        embedding = action_embeddings[b, s].reshape(num_action_types, -1)

        if is_target_mask[b, s]:
          # Target positions get the full target-action embedding table.
          npt.assert_allclose(embedding, target_table, atol=1e-6)
        else:
          # History positions get the action-type embedding for each enabled
          # type and zeros for every disabled type.
          current_enabled = enabled_actions[history_item_idx]
          for atype in range(num_action_types):
            if atype in current_enabled:
              npt.assert_allclose(
                  embedding[atype], action_table[atype], atol=1e-6
              )
            else:
              npt.assert_allclose(embedding[atype], 0, atol=1e-6)
          history_item_idx += 1

    def loss_fn(p):
      # Scalar reduction so jax.grad can differentiate the encoder output.
      return encoder.apply({"params": p}, seq_payloads, is_target_mask).sum()

    grads = jax.grad(loss_fn)(params)
    # jax.grad always returns a pytree, so check for the expected entries
    # rather than a vacuous not-None assertion.
    self.assertIn("action_embedding_table", grads)
    self.assertIn("target_action_embedding_table", grads)
    # Both tables must receive some nonzero gradient from the summed output.
    self.assertFalse(np.all(np.isclose(grads["action_embedding_table"], 0)))
    self.assertFalse(
        np.all(np.isclose(grads["target_action_embedding_table"], 0))
    )
| 144 | + |
| 145 | + |
| 146 | +if __name__ == "__main__": |
| 147 | + absltest.main() |