diff --git a/tzrec/utils/dynamicemb_util.py b/tzrec/utils/dynamicemb_util.py index 09e9bf2b..3eeb9c88 100644 --- a/tzrec/utils/dynamicemb_util.py +++ b/tzrec/utils/dynamicemb_util.py @@ -9,9 +9,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +import dataclasses import math import os -from typing import List, Optional, Tuple, Type, cast +from typing import Any, List, Optional, Tuple, Type, cast import torch from torch import nn @@ -21,7 +22,10 @@ planners, shard_estimators, ) -from torchrec.distributed.planner.estimator.types import HardwarePerfConfig +from torchrec.distributed.planner.estimator.types import ( + HardwarePerfConfig, + ShardPerfContext, +) from torchrec.distributed.planner.types import ( ParameterConstraints, ShardingOption, @@ -43,6 +47,89 @@ from torchrec.modules.embedding_configs import BaseEmbeddingConfig from tzrec.protos import feature_pb2 +from tzrec.utils.logging_util import logger + +_DYNAMICEMB_CACHING_X_EFF_BASE = 0.28 +_DYNAMICEMB_HYBRID_X_EFF_BASE = 0.11 +_DYNAMICEMB_X_EFF_TIEBREAK = 0.01 + + +def _dynamicemb_effective_cache_ratio( + cache_load_factor: Optional[float], + caching: bool, + stats: Optional[Any] = None, +) -> float: + """Effective HBM-hit ratio for the dynamicemb perf model. + + Returns the value passed to torchrec's perf bandwidth formula + ``bw = x_eff*hbm + (1-x_eff)*hbm_to_ddr_bw``. Larger value = faster path. + + The ratio is derived from an on-device perf sweep, not a heuristic. + Empirical pattern (alpha=1.05 pow-law access on A10): + + * ``x == 1.0``: the runtime *switches kernels* — when + ``total_value_memory <= local_hbm_for_values`` the dual-tier + ``HybridStorage`` / ``DynamicEmbCache`` paths are dropped in favor + of the HBM-only ``DynamicEmbStorage`` kernel + (``batched_dynamicemb_tables.py:502-604``). The ~8x jump in ``x_eff`` + between ``x=0.9`` and ``x=1.0`` is intentional and matches measured + latency, not a smoothing artifact. (A future refactor could lift + this to a discrete ``mode={HBM_ONLY, HYBRID, CACHING}`` axis on the + enumerator side rather than packing the discontinuity into ``x``.) + * ``caching=True``, ``x < 1.0``: 3.3x slower than HBM-only -> base 0.28. + * ``caching=False``, ``x < 1.0``: 6.8x slower than HBM-only -> base 0.11. + + Within each ``x < 1.0`` block the perf is roughly flat in ratio, but we + add a tiny monotonic perturbation so the DP can break ties. + + If ``stats`` is provided, ``1 - stats.expected_miss_rate(x)`` overrides + the heuristic verbatim (clamped to ``[0, 1]``); the caller opts in to + their own measurement. + """ + x = float(cache_load_factor) if cache_load_factor is not None else 0.0 + x = max(0.0, min(1.0, x)) + if stats is not None: + miss_rate = float(stats.expected_miss_rate(x)) + return max(0.0, min(1.0, 1.0 - miss_rate)) + if x >= 1.0: + return 1.0 + base = _DYNAMICEMB_CACHING_X_EFF_BASE if caching else _DYNAMICEMB_HYBRID_X_EFF_BASE + return base + _DYNAMICEMB_X_EFF_TIEBREAK * x + + +def _log_dynamicemb_table_plan( + *, + fqn: str, + cache_load_factor: float, + caching: bool, + hbm_bytes: int, + ddr_bytes: int, +) -> None: + """Per-table mode log on rank 0. + + cache_load_factor=1.0 forces the runtime into HBM_ONLY (host tier + dropped) regardless of ``caching`` -- mirror that override here so + the log matches the runtime, not the planner's recorded + (caching, factor). Use exact equality so a stray >1.0 fails loud + downstream instead of being silently relabelled HBM_ONLY. + """ + if int(os.environ.get("RANK", 0)) != 0: + return + if cache_load_factor == 1.0: + mode = "HBM_ONLY" + dram_bytes = 0 # runtime drops the host tier + else: + mode = "CACHING" if caching else "HYBRID" + dram_bytes = ddr_bytes + hbm_gib = hbm_bytes / (1 << 30) + dram_gib = dram_bytes / (1 << 30) + logger.info( + f"[dynamicemb plan] {fqn}: mode={mode} " + f"cache_load_factor={cache_load_factor:.2f} " + f"local_hbm={hbm_gib:.3f}GiB " + f"local_dram={dram_gib:.3f}GiB" + ) + has_dynamicemb = False try: @@ -258,18 +345,36 @@ def _calculate_dynamicemb_table_storage_specific_size( is_hbm: bool = True, only_values: bool = False, bucket_capacity: int = 128, + caching: bool = False, ) -> int: - """Calculate dynamic embedding table storage. - - total_value_memory = max_capacity x aligned16(embedding+optimizer states) - num_buckets = max_capacity/bucket_capacity - hbm_budget = min(global_hbm_for_values//world_size, total_value_memory) + - max_capacity x (key<8byte> + score<8byte> + digest<1byte>) + - num_buckets x (bucket_size<4byte>) - ddr_budget = max(total_value_memory - global_hbm_for_values//world_size, 0) + """Per-shard storage size for a dynamicemb table -- HBM or DDR (bytes). + + Byte budget (single shard, rows x dim): + + value_bytes_per_row = round_up16(dim * (1 + opt_mult) * element) + total_value_memory = align(rows) * value_bytes_per_row + num_buckets = align(rows) / bucket_capacity + + hbm_budget = cache_ratio * total_value_memory # values + + align(rows) * (key<8B> + score<8B> + digest<1B>) # per-row + + num_buckets * bucket_header<4B> # per-bucket + + ddr_budget = HYBRID (caching=False): (1 - cache_ratio) * total_value_memory + CACHING (caching=True): total_value_memory # full backing + + HYBRID hash-partitions values across HBM and host; ``cache_ratio`` is + HBM's value share. CACHING keeps the full backing store on host and + uses HBM as a hot-row cache of size + ``cache_ratio * total_value_memory``. Hash-table metadata + (key + score + digest + bucket header) is accounted on HBM only -- + matches the existing tzrec convention. """ if cache_ratio is None: cache_ratio = 1.0 + if is_hbm: + value_ratio = cache_ratio + else: + value_ratio = 1.0 if caching else (1.0 - cache_ratio) return math.ceil( align_to_table_size(size[0]) * ( @@ -277,7 +382,7 @@ def _calculate_dynamicemb_table_storage_specific_size( math.ceil(size[1] * (1 + optimizer_multipler) * element_size), 16, ) - * (cache_ratio if is_hbm else 1 - cache_ratio) + * value_ratio + (8 + 8 + 1 + 4 / bucket_capacity) * (is_hbm and not only_values) ) ) @@ -363,6 +468,13 @@ def _to_sharding_plan( dist_type="roundrobin", dynamicemb_options=dynamicemb_options, ) + _log_dynamicemb_table_plan( + fqn=f"{sharding_option.path}.{sharding_option.name}", + cache_load_factor=float(sharding_option.cache_load_factor), + caching=bool(dynamicemb_options.caching), + hbm_bytes=int(shards[0].storage.hbm), + ddr_bytes=int(shards[0].storage.ddr), + ) else: module_plan[sharding_option.name] = ParameterSharding( sharding_spec=sharding_spec, @@ -413,6 +525,66 @@ def _customized_kernel_aware_get_device_bw( # pyre-ignore [9] HardwarePerfConfig.get_device_bw = _customized_kernel_aware_get_device_bw + _orig_build_shard_perf_contexts = ( + ShardPerfContext.build_shard_perf_contexts.__func__ + ) + + def _dynamicemb_aware_build_shard_perf_contexts( + cls, # pyre-ignore [2] + config, # pyre-ignore [2] + shard_sizes, # pyre-ignore [2] + sharding_option, # pyre-ignore [2] + topology, # pyre-ignore [2] + constraints, # pyre-ignore [2] + sharder, # pyre-ignore [2] + *args, # pyre-ignore [2] + **kwargs, # pyre-ignore [2] + ): + """Inject the empirical x_eff into the perf estimator for both modes. + + Temporarily replace ``sharding_option.cache_params`` with a clone + whose ``load_factor`` is the empirically-fitted x_eff for the + (mode, cache_load_factor) combination. Restored before returning so + the (separately invoked) storage estimator still sees the un-boosted + ratio. + """ + dynamicemb_options = getattr(sharding_option, "dynamicemb_options", None) + original_cache_params = sharding_option.cache_params + if dynamicemb_options is not None: + caching = bool(getattr(dynamicemb_options, "caching", False)) + stats = original_cache_params.stats if original_cache_params else None + x_eff = _dynamicemb_effective_cache_ratio( + sharding_option.cache_load_factor, caching=caching, stats=stats + ) + sharding_option.cache_params = ( + dataclasses.replace(original_cache_params, load_factor=x_eff) + if original_cache_params is not None + else CacheParams(load_factor=x_eff) + ) + # try/finally so an estimator exception cannot leak the boosted + # cache_params clone into the storage estimator's view of the + # same ShardingOption. + try: + result = _orig_build_shard_perf_contexts( + cls, + config, + shard_sizes, + sharding_option, + topology, + constraints, + sharder, + *args, + **kwargs, + ) + finally: + sharding_option.cache_params = original_cache_params + return result + + # pyre-ignore [9] + ShardPerfContext.build_shard_perf_contexts = classmethod( + _dynamicemb_aware_build_shard_perf_contexts + ) + def _calculate_dynamicemb_storage_specific_sizes( tensor: torch.Tensor, shard_sizes: List[List[int]], @@ -420,6 +592,7 @@ def _calculate_dynamicemb_storage_specific_sizes( cache_ratio: float = 1.0, is_inference: bool = False, bucket_capacity: int = 128, + caching: bool = False, ) -> Tuple[List[int], List[int]]: """Calculate storage for dynamicemb.""" optimizer_multipler = 0.0 @@ -437,6 +610,7 @@ def _calculate_dynamicemb_storage_specific_sizes( cache_ratio, is_hbm=True, bucket_capacity=bucket_capacity, + caching=caching, ) for size in shard_sizes ] @@ -449,6 +623,7 @@ def _calculate_dynamicemb_storage_specific_sizes( cache_ratio, is_hbm=False, bucket_capacity=bucket_capacity, + caching=caching, ) for size in shard_sizes ] @@ -496,7 +671,10 @@ def dynamicemb_calculate_shard_storages( factors. num_poolings (List[float]): average number of poolings per sample (typically 1.0). - caching_ratio (float): ratio of HBM to DDR memory for UVM caching. + caching_ratio (float): cache_load_factor for the dynamicemb table. + In HYBRID mode HBM holds this fraction of values and host + holds the remainder; in CACHING mode HBM is a hot-row cache + of this fraction and host holds the full backing store. is_pooled (bool): True if embedding output is pooled (ie. `EmbeddingBag`), False if unpooled/sequential (ie. `Embedding`). input_data_type_size (int): number of bytes of input data type. @@ -535,6 +713,7 @@ def dynamicemb_calculate_shard_storages( cache_ratio=caching_ratio if caching_ratio else 1.0, is_inference=is_inference, bucket_capacity=dynamicemb_options.bucket_capacity, + caching=bool(getattr(dynamicemb_options, "caching", False)), ) ) counter_hbm_specific_size = 0 diff --git a/tzrec/utils/dynamicemb_util_test.py b/tzrec/utils/dynamicemb_util_test.py new file mode 100644 index 00000000..8d8cc2f3 --- /dev/null +++ b/tzrec/utils/dynamicemb_util_test.py @@ -0,0 +1,110 @@ +# Copyright (c) 2026, Alibaba Group; +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from parameterized import parameterized + +from tzrec.utils import dynamicemb_util + + +@unittest.skipUnless( + dynamicemb_util.has_dynamicemb, "dynamicemb is not installed; skipping." +) +class StorageFormulaTest(unittest.TestCase): + """Mode-aware ``_calculate_dynamicemb_table_storage_specific_size``.""" + + ROWS = 1024 + DIM = 64 + ELEMENT_SIZE = 4 + BUCKET_CAPACITY = 128 + + def _calc(self, *, cache_ratio, is_hbm, caching, only_values=False): + return dynamicemb_util._calculate_dynamicemb_table_storage_specific_size( + size=[self.ROWS, self.DIM], + element_size=self.ELEMENT_SIZE, + cache_ratio=cache_ratio, + is_hbm=is_hbm, + only_values=only_values, + bucket_capacity=self.BUCKET_CAPACITY, + caching=caching, + ) + + @parameterized.expand( + [ + ("ratio_0_0", 0.0), + ("ratio_0_25", 0.25), + ("ratio_0_5", 0.5), + ("ratio_0_75", 0.75), + ("ratio_1_0", 1.0), + ] + ) + def test_hbm_identical_between_modes(self, _name, cache_ratio): + # HBM accounting is the same in HYBRID and CACHING: HBM holds a + # cache_ratio fraction of values plus full-row-count metadata. + hybrid_hbm = self._calc(cache_ratio=cache_ratio, is_hbm=True, caching=False) + caching_hbm = self._calc(cache_ratio=cache_ratio, is_hbm=True, caching=True) + self.assertEqual(hybrid_hbm, caching_hbm) + + @parameterized.expand( + [ + ("ratio_0_0", 0.0), + ("ratio_0_25", 0.25), + ("ratio_0_5", 0.5), + ("ratio_0_75", 0.75), + ("ratio_1_0", 1.0), + ] + ) + def test_ddr_hybrid_complements_cache(self, _name, cache_ratio): + # HYBRID DDR = (1 - cache_ratio) * full-table DDR. + full_ddr = self._calc(cache_ratio=0.0, is_hbm=False, caching=False) + hybrid_ddr = self._calc(cache_ratio=cache_ratio, is_hbm=False, caching=False) + self.assertEqual(hybrid_ddr, round((1.0 - cache_ratio) * full_ddr)) + + @parameterized.expand( + [ + ("ratio_0_0", 0.0), + ("ratio_0_25", 0.25), + ("ratio_0_5", 0.5), + ("ratio_0_75", 0.75), + ("ratio_1_0", 1.0), + ] + ) + def test_ddr_caching_holds_full_table(self, _name, cache_ratio): + # CACHING DDR is the full backing store, independent of cache_ratio. + full_ddr = self._calc(cache_ratio=0.0, is_hbm=False, caching=False) + caching_ddr = self._calc(cache_ratio=cache_ratio, is_hbm=False, caching=True) + self.assertEqual(caching_ddr, full_ddr) + + def test_caching_ddr_strictly_greater_than_hybrid_when_cached(self): + for cache_ratio in (0.1, 0.5, 0.9): + hybrid_ddr = self._calc( + cache_ratio=cache_ratio, is_hbm=False, caching=False + ) + caching_ddr = self._calc( + cache_ratio=cache_ratio, is_hbm=False, caching=True + ) + self.assertGreater(caching_ddr, hybrid_ddr) + + def test_only_values_drops_metadata(self): + # only_values=True strips HBM metadata regardless of mode. + for caching in (False, True): + with_meta = self._calc( + cache_ratio=0.5, is_hbm=True, caching=caching, only_values=False + ) + without_meta = self._calc( + cache_ratio=0.5, is_hbm=True, caching=caching, only_values=True + ) + self.assertGreater(with_meta, without_meta) + + +if __name__ == "__main__": + unittest.main() diff --git a/tzrec/utils/plan_util.py b/tzrec/utils/plan_util.py index bf5c69e0..29ec26a1 100644 --- a/tzrec/utils/plan_util.py +++ b/tzrec/utils/plan_util.py @@ -16,6 +16,7 @@ from queue import Queue from typing import Any, Dict, List, Optional, Tuple, Union, cast +import numpy as np import psutil import torch from torch import distributed as dist @@ -87,10 +88,6 @@ from tzrec.utils.logging_util import logger -def _bytes_to_float_bin(num_bytes: Union[float, int], bin_size: float) -> float: - return float(num_bytes) / bin_size - - def create_planner( device: torch.device, batch_size: int, @@ -190,7 +187,18 @@ def create_planner( global_constraints=global_constraints, ), storage_reservation=storage_reservation, - proposer=[DynamicProgrammingProposer(), UniformProposer()], + # DP bin counts are env-tunable; defaults match the proposer signature. + proposer=[ + DynamicProgrammingProposer( + hbm_bins_per_device=int( + os.environ.get("TZREC_DP_HBM_BINS_PER_DEVICE", "100") + ), + ddr_bins_per_device=int( + os.environ.get("TZREC_DP_DDR_BINS_PER_DEVICE", "25") + ), + ), + UniformProposer(), + ], debug=True, ) return planner @@ -224,70 +232,202 @@ def get_default_sharders() -> List[ModuleSharder[nn.Module]]: return sharders +_INF = float("inf") + + +def _argmin_per_group(group_keys: np.ndarray, perf_keys: np.ndarray) -> np.ndarray: + """Index of the min-``perf_keys`` entry per distinct ``group_keys`` value. + + Both inputs must be 1-D and the same length. Ties on ``perf_keys`` are + broken by input order. Output indexes into the original (pre-sort) + arrays, ordered by ascending ``group_keys``. + """ + if group_keys.size == 0: + return np.empty(0, dtype=np.int64) + lexsort_order = np.lexsort((perf_keys, group_keys)) + sorted_groups = group_keys[lexsort_order] + is_first_of_run = np.empty(sorted_groups.shape, dtype=bool) + is_first_of_run[0] = True + is_first_of_run[1:] = sorted_groups[1:] != sorted_groups[:-1] + return lexsort_order[is_first_of_run] + + +def _sparse_dp_proposor_numpy( + table_opts: List[Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]], + hbm_bins: int, + ddr_bins: int, +) -> List[List[int]]: + """Sparse-K NumPy 2D DP over (hbm_bin, ddr_bin) reachable cells. + + Per table, broadcasts (K_prev, N) candidates and uses :func:`_argmin_per_group` + (lexsort + first-of-run) for groupby-argmin on the destination cell. + Reachable cells K_t saturate at the actual reachable count, not + ``hbm_bins * ddr_bins``. Backtracking walks T x K_t int32 arrays in + O(T) Python steps. + + ``table_opts`` entries are ``(opt_hbm, opt_ddr, opt_perf, opt_global_id)`` + per table, all 1-D NumPy arrays (the float ones in bin-units; the index + one in opt-id space, used to reconstruct proposals). + + Returns a list of proposals; each proposal is a list of opt-id integers, + one per table, emitted in decreasing-HBM order with at most one entry per + HBM bin (the perf-best plan across all DDR bins at that HBM level). + """ + table_count = len(table_opts) + if table_count == 0: + return [] + + seed_hbm, seed_ddr, seed_perf, seed_opt_id = table_opts[0] + valid_mask = (seed_hbm < hbm_bins) & (seed_ddr < ddr_bins) + if not valid_mask.any(): + return [] + seed_hbm = seed_hbm[valid_mask] + seed_ddr = seed_ddr[valid_mask] + seed_perf = seed_perf[valid_mask] + seed_opt_id = seed_opt_id[valid_mask] + seed_hbm_i = seed_hbm.astype(np.int32) + seed_ddr_i = seed_ddr.astype(np.int32) + flat_cell_i = seed_hbm_i.astype(np.int64) * ddr_bins + seed_ddr_i + winners = _argmin_per_group(group_keys=flat_cell_i, perf_keys=seed_perf) + + dp_perf = seed_perf[winners] + dp_hbm = seed_hbm[winners] + dp_ddr = seed_ddr[winners] + dp_hbm_i = seed_hbm_i[winners] + back_opt_j: List[np.ndarray] = [seed_opt_id[winners]] + back_prev_cell_i: List[np.ndarray] = [np.full(winners.size, -1, dtype=np.int32)] + + for table_i in range(1, table_count): + if dp_perf.size == 0: + break + cur_table_hbm, cur_table_ddr, cur_table_perf, cur_table_opt_id = table_opts[ + table_i + ] + if cur_table_perf.size == 0: + dp_perf = np.zeros(0) + break + + new_hbm = dp_hbm[:, None] + cur_table_hbm[None, :] + new_ddr = dp_ddr[:, None] + cur_table_ddr[None, :] + new_perf = dp_perf[:, None] + cur_table_perf[None, :] + valid_mask = (new_hbm < hbm_bins) & (new_ddr < ddr_bins) + if not valid_mask.any(): + dp_perf = np.zeros(0) + break + + new_hbm_i = new_hbm.astype(np.int32) + new_ddr_i = new_ddr.astype(np.int32) + flat_cell_i = new_hbm_i.astype(np.int64) * ddr_bins + new_ddr_i + + valid_flat_cell_i = flat_cell_i[valid_mask] + valid_new_perf = new_perf[valid_mask] + valid_new_hbm = new_hbm[valid_mask] + valid_new_ddr = new_ddr[valid_mask] + valid_new_hbm_i = new_hbm_i[valid_mask] + valid_prev_cell_i, valid_opt_j = np.where(valid_mask) + + winners = _argmin_per_group( + group_keys=valid_flat_cell_i, perf_keys=valid_new_perf + ) + + dp_perf = valid_new_perf[winners] + dp_hbm = valid_new_hbm[winners] + dp_ddr = valid_new_ddr[winners] + dp_hbm_i = valid_new_hbm_i[winners] + back_opt_j.append(cur_table_opt_id[valid_opt_j[winners]]) + back_prev_cell_i.append(valid_prev_cell_i[winners].astype(np.int32)) + + if dp_perf.size == 0: + return [] + + # Per-HBM-bin best, decreasing HBM order. + chosen_cell_i = _argmin_per_group(group_keys=dp_hbm_i, perf_keys=dp_perf)[::-1] + + proposals: List[List[int]] = [] + for last_cell_i in chosen_cell_i: + proposal_indices = [-1] * table_count + cur_cell_i = int(last_cell_i) + for table_i in range(table_count - 1, -1, -1): + proposal_indices[table_i] = int(back_opt_j[table_i][cur_cell_i]) + cur_cell_i = int(back_prev_cell_i[table_i][cur_cell_i]) + proposals.append(proposal_indices) + return proposals + + class DynamicProgrammingProposer(Proposer): - r"""Proposes sharding plans in dynamic programming fashion. + r"""Proposes sharding plans in 2D (HBM × DDR) dynamic programming fashion. The problem of the Embedding Sharding Plan can be framed as follows: Given :math:`M` tables and their corresponding :math:`N` Sharding Options, we need to select one sharding option for each table such that the total performance is - minimized, while keeping the overall memory constraint :math:`K` in check. This can - be abstracted into the following mathematical formulation: + minimized, while keeping both an HBM constraint :math:`K_h` and a host DDR + constraint :math:`K_d` in check. This can be abstracted into the following + mathematical formulation: - Given a matrix :math:`A` of dimensions :math:`(M, N)` and another matrix :math:`B` - of the same dimensions, let the elements of matrix :math:`A` be denoted as - :math:`a_{i,j}` and the elements of matrix :math:`B` as :math:`b_{i,j}`. We aim - to find a set of column indices :math:`\{ j_0, j_1, \ldots, j_{M-1} \}` such that - the following conditions are satisfied: + Given matrices :math:`A^h`, :math:`A^d`, and :math:`B` of dimensions + :math:`(M, N)`, let :math:`a^h_{i,j}` and :math:`a^d_{i,j}` be the per-option + HBM and DDR storage costs, and :math:`b_{i,j}` the perf cost. We aim to find a + set of column indices :math:`\{ j_0, j_1, \ldots, j_{M-1} \}` such that the + following conditions are satisfied: - 1. :math:`\sum_{i=0}^{M-1} a_{i,j_i} \leq K`, where :math:`K` is a float. - 2. :math:`\sum_{i=0}^{M-1} b_{i,j_i}` is minimized. + 1. :math:`\sum_{i=0}^{M-1} a^h_{i,j_i} \leq K_h`. + 2. :math:`\sum_{i=0}^{M-1} a^d_{i,j_i} \leq K_d`. + 3. :math:`\sum_{i=0}^{M-1} b_{i,j_i}` is minimized. - This problem can be tackled using dynamic programming. First, discretize :math:`K` - into :math:`K_i`, and denote the discretization function as :math:`f`. + This problem can be tackled using 2D dynamic programming. First, discretize + :math:`K_h` and :math:`K_d` into bins, and denote the discretization functions + as :math:`f_h` and :math:`f_d`. - Define the state :math:`dp[i][f(k)]` to represent the minimum value of :math:`B` - when considering the first :math:`i` rows and the total sum of :math:`A` is equal to - the discretized value :math:`k`. + Define the state :math:`dp[i][f_h(k_h)][f_d(k_d)]` to represent the minimum + value of :math:`B` when considering the first :math:`i` rows and the totals of + :math:`A^h` and :math:`A^d` equal the discretized values :math:`k_h` and + :math:`k_d` respectively. The state transition can then be represented as: .. math:: - dp[i][f(k)] = \min_{j=0}^{N-1} \left( dp[i-1][f(k - A[i][j])] + B[i][j] \right) + dp[i][f_h(k_h)][f_d(k_d)] = \min_{j=0}^{N-1} \left( + dp[i-1][f_h(k_h - a^h_{i,j})][f_d(k_d - a^d_{i,j})] + b_{i,j} \right) - Since :math:`K` is the sum allocated across all memory, simply satisfying that the - total memory in the plan equals :math:`K` does not guarantee that the allocation - will fit on all cards. Therefore, it is essential to maintain all the states of the - last layer of :math:`dp`. This allows us to propose different plans under varying - total memory constraints. + Since :math:`K_h` and :math:`K_d` are sums allocated across all memory, simply + satisfying that the totals in the plan equal them does not guarantee that the + allocation will fit on all cards / hosts. Therefore, it is essential to + maintain all the states of the last layer of :math:`dp`. For each HBM bin we + emit one proposal -- the lowest-:math:`B` plan across all DDR bins at that + HBM level -- in decreasing HBM order; plans at the same HBM bin with worse + perf are strictly dominated and skipped. Args: - mem_bins_per_device (int): memory bins for dynamic programming precision. + hbm_bins_per_device (int): per-device HBM bins for DP precision. + ddr_bins_per_device (int): per-device DDR bins for DP precision. """ - def __init__(self, mem_bins_per_device: int = 100) -> None: + def __init__( + self, + hbm_bins_per_device: int = 100, + ddr_bins_per_device: int = 25, + ) -> None: self._inited: bool = False - self._mem_bins_per_device: int = max(mem_bins_per_device, 1) + self._hbm_bins_per_device: int = max(hbm_bins_per_device, 1) + self._ddr_bins_per_device: int = max(ddr_bins_per_device, 1) self._sharding_options_by_fqn: OrderedDict[str, List[ShardingOption]] = ( OrderedDict() ) - # list of proposals with different total_mem, a proposal is a list of - # indices of sharding_options + # list of proposals with different total_mem; each proposal is a list + # of indices into self._sharding_options_by_fqn[fqn]. self._proposal_list: List[List[int]] = [] self._current_proposal: int = -1 - self._storage_type = "hbm" - if not torch.cuda.is_available(): - self._storage_type = "ddr" def load( self, search_space: List[ShardingOption], enumerator: Optional[Enumerator] = None, ) -> None: - """Load search space.""" + """Load search space, sorted by total (hbm + ddr) ascending.""" self._reset() - # order the sharding_option by total_storage.hbm from low to high for sharding_option in sorted( - search_space, key=lambda x: getattr(x.total_storage, self._storage_type) + search_space, + key=lambda x: (x.total_storage.hbm or 0) + (x.total_storage.ddr or 0), ): fqn = sharding_option.fqn if fqn not in self._sharding_options_by_fqn: @@ -324,105 +464,94 @@ def feedback( perf_rating: Optional[float] = None, storage_constraint: Optional[Topology] = None, ) -> None: - """Feedback last proposed plan.""" - if not self._inited: - self._inited = True - table_count = len(self._sharding_options_by_fqn) - option_count = max([len(x) for x in self._sharding_options_by_fqn.values()]) - - assert storage_constraint is not None - # are we assuming the table will be evenly sharded on all devices? - max_device_mem = 0 - mem_total = 0 - for x in storage_constraint.devices: - cur_device_mem = getattr(x.storage, self._storage_type) - max_device_mem = max(max_device_mem, cur_device_mem) - mem_total += cur_device_mem - - bin_count = self._mem_bins_per_device * len(storage_constraint.devices) - bin_size = float(mem_total) / bin_count - - dp = [ - [(float("inf"), float("inf"))] * bin_count for _ in range(table_count) - ] # [table_id][mem_bin][perf, mem] - - backtrack = [ - [(-1, -1)] * bin_count for _ in range(table_count) - ] # [table_id][mem_bin][opt_id, prev_mem_bin] - - mem_by_fqn = [ - [float("inf") for _ in range(option_count)] for _ in range(table_count) - ] # memory constraint lookup table: [table_id][sharding_option_id] - perf_by_fqn = [ - [float("inf") for _ in range(option_count)] for _ in range(table_count) - ] # performance metrics lookup table: [table_id][sharding_option_id] - - # populate mem and perf for each sharding option and table: - # A[table_id][sharding_option_id] - for table_id, sharding_options in enumerate( - self._sharding_options_by_fqn.values() - ): - for opt_id, sharding_option in enumerate(sharding_options): - # prune mem of one shard > mem of one device - if ( - max( - [ - getattr(shard.storage, self._storage_type) - for shard in sharding_option.shards - ] - ) - > max_device_mem - ): - continue - mem_by_fqn[table_id][opt_id] = _bytes_to_float_bin( - getattr(sharding_option.total_storage, self._storage_type), - bin_size, - ) - perf_by_fqn[table_id][opt_id] = sharding_option.total_perf - - table_0 = 0 - for opt_j in range(option_count): - if mem_by_fqn[0][opt_j] < bin_count: - mem_i = int(mem_by_fqn[0][opt_j]) - # options are ordered in increasing order of mem, we only want to - # consider a sharding option that has higher mem and better perf - # (the smaller the better) - if dp[table_0][mem_i][0] > perf_by_fqn[table_0][opt_j]: - dp[table_0][mem_i] = ( - perf_by_fqn[table_0][opt_j], - mem_by_fqn[table_0][opt_j], - ) - backtrack[table_0][mem_i] = (opt_j, -1) - - # dp: table_count x option_count x bin_count - for table_i in range(1, table_count): - for opt_j in range(option_count): - for mem in range(bin_count): - prev_perf, perv_mem = dp[table_i - 1][mem] - if prev_perf < float("inf"): - new_mem = perv_mem + mem_by_fqn[table_i][opt_j] - if new_mem < bin_count: - new_mem_i = int(new_mem) - new_perf = prev_perf + perf_by_fqn[table_i][opt_j] - if dp[table_i][new_mem_i][0] > new_perf: - dp[table_i][new_mem_i] = (new_perf, new_mem) - backtrack[table_i][new_mem_i] = (opt_j, mem) - self._proposal_list = [] - # fill in all the proposals, starting from highest mem to lowest mem - for c in range(bin_count - 1, -1, -1): - cur_opt_idx, cur_mem_idx = backtrack[table_count - 1][c] - if cur_opt_idx >= 0: - proposal_indices = [-1] * table_count - proposal_indices[table_count - 1] = cur_opt_idx - for i in range(table_count - 2, -1, -1): - proposal_indices[i], cur_mem_idx = backtrack[i][cur_mem_idx] - self._proposal_list.append(proposal_indices) - if len(self._proposal_list) > 0: - self._current_proposal = 0 - else: + """Run 2D DP on first feedback; otherwise advance the proposal cursor.""" + if self._inited: self._current_proposal += 1 if self._current_proposal >= len(self._proposal_list): self._current_proposal = -1 + return + + self._inited = True + assert storage_constraint is not None + if not self._sharding_options_by_fqn: + return + + num_devices = len(storage_constraint.devices) + max_device_hbm = 0 + hbm_total = 0 + ddr_total = 0 + for device in storage_constraint.devices: + max_device_hbm = max(max_device_hbm, device.storage.hbm or 0) + hbm_total += device.storage.hbm or 0 + ddr_total += device.storage.ddr or 0 + # DDR is host-shared across ranks co-located on one machine, so the + # per-option fit check compares against the largest machine's DDR pool + # -- not per-device. HBM is GPU-local, so its prune stays per-device. + per_host = getattr(storage_constraint, "local_world_size", None) or num_devices + per_host = max(per_host, 1) + max_machine_ddr = 0 + for host_start in range(0, num_devices, per_host): + host_end = min(host_start + per_host, num_devices) + machine_ddr = sum( + (storage_constraint.devices[i].storage.ddr or 0) + for i in range(host_start, host_end) + ) + max_machine_ddr = max(max_machine_ddr, machine_ddr) + + hbm_bins = max(self._hbm_bins_per_device * num_devices, 1) + ddr_bins = max(self._ddr_bins_per_device * num_devices, 1) + # Collapse a degenerate axis to a single bin so we don't waste states + # on (e.g.) CPU-only topologies that have hbm == 0 everywhere. + if hbm_total == 0: + hbm_bins = 1 + if ddr_total == 0: + ddr_bins = 1 + hbm_bin_size = float(hbm_total) / hbm_bins if hbm_bins > 0 else 1.0 + ddr_bin_size = float(ddr_total) / ddr_bins if ddr_bins > 0 else 1.0 + + # Per-table option arrays in bin-units, with infeasible options pruned. + table_opts: List[Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]] = [] + for sharding_options in self._sharding_options_by_fqn.values(): + hbm_list: List[float] = [] + ddr_list: List[float] = [] + perf_list: List[float] = [] + opt_id_list: List[int] = [] + for opt_id, sharding_option in enumerate(sharding_options): + max_shard_hbm = max( + (shard.storage.hbm or 0) for shard in sharding_option.shards + ) + max_shard_ddr = max( + (shard.storage.ddr or 0) for shard in sharding_option.shards + ) + # HBM is per-device, DDR is per-machine: see comment above. + if hbm_total > 0 and max_shard_hbm > max_device_hbm: + continue + if ddr_total > 0 and max_shard_ddr > max_machine_ddr: + continue + hbm_list.append( + (sharding_option.total_storage.hbm or 0) / hbm_bin_size + if hbm_total > 0 + else 0.0 + ) + ddr_list.append( + (sharding_option.total_storage.ddr or 0) / ddr_bin_size + if ddr_total > 0 + else 0.0 + ) + perf_list.append(sharding_option.total_perf) + opt_id_list.append(opt_id) + table_opts.append( + ( + np.asarray(hbm_list, dtype=np.float32), + np.asarray(ddr_list, dtype=np.float32), + np.asarray(perf_list, dtype=np.float32), + np.asarray(opt_id_list, dtype=np.int32), + ) + ) + + self._proposal_list = _sparse_dp_proposor_numpy(table_opts, hbm_bins, ddr_bins) + if self._proposal_list: + self._current_proposal = 0 def _extract_constraints_for_param( @@ -755,6 +884,38 @@ def calculate_shard_storages( ) +def _emit_dynamicemb_variants( + base_option: ShardingOption, +) -> List[ShardingOption]: + """Expand a dynamicemb ShardingOption into HYBRID + CACHING variants. + + Sweeps both placement modes (``caching=False`` and ``caching=True``) and, + when ``base_option.cache_params`` is unset, ten cache_load_factor values + (0.1, 0.2, ..., 1.0). The downstream 2D DP proposer picks per table the + best (mode, ratio) that fits both HBM and host topology budgets. + + ``base_option.dynamicemb_options`` must already be attached by the + caller; each returned ShardingOption owns a freshly deep-copied + ``dynamicemb_options`` instance so per-option ``caching`` mutations do + not bleed across variants. + """ + if base_option.cache_params is None: + load_factors = [(i + 1) / 10 for i in range(10)] + stats = None + else: + load_factors = [base_option.cache_params.load_factor] + stats = base_option.cache_params.stats + variants: List[ShardingOption] = [] + for caching_mode in (False, True): + for load_factor in load_factors: + opt = copy.deepcopy(base_option) + opt.cache_params = CacheParams(load_factor=load_factor, stats=stats) + # deepcopy(base_option) already produced a fresh dynamicemb_options. + opt.dynamicemb_options.caching = caching_mode # pyre-ignore [16] + variants.append(opt) + return variants + + class EmbeddingEnumerator(_EmbeddingEnumerator): """Generates embedding sharding options for given `nn.Module` with constraints. @@ -934,20 +1095,9 @@ def enumerate( sharding_option.use_dynamicemb = use_dynamicemb # pyre-ignore [16] sharding_option.dynamicemb_options = dynamicemb_options - if sharding_option.cache_params is None: - # add cache_load_factor automatic search space - for load_factor_step in range(10): - sharding_option_copy = copy.deepcopy( - sharding_option - ) - sharding_option_copy.cache_params = CacheParams( - load_factor=(load_factor_step + 1) / 10 - ) - sharding_options_per_table.append( - sharding_option_copy - ) - else: - sharding_options_per_table.append(sharding_option) + sharding_options_per_table.extend( + _emit_dynamicemb_variants(sharding_option) + ) else: sharding_options_per_table.append(sharding_option) diff --git a/tzrec/utils/plan_util_test.py b/tzrec/utils/plan_util_test.py index 868b09df..c7a12450 100644 --- a/tzrec/utils/plan_util_test.py +++ b/tzrec/utils/plan_util_test.py @@ -9,18 +9,26 @@ # See the License for the specific language governing permissions and # limitations under the License. +import random import unittest +from types import SimpleNamespace +from typing import List, Tuple +import numpy as np import torch +from parameterized import parameterized +from torchrec.distributed.embedding_types import EmbeddingComputeKernel from torchrec.distributed.model_parallel import get_default_sharders from torchrec.distributed.planner.enumerators import EmbeddingEnumerator from torchrec.distributed.planner.partitioners import GreedyPerfPartitioner from torchrec.distributed.planner.proposers import GridSearchProposer from torchrec.distributed.planner.types import PlannerError, Topology from torchrec.distributed.test_utils.test_model import TestSparseNN +from torchrec.distributed.types import ShardingType from torchrec.modules.embedding_configs import EmbeddingBagConfig -from tzrec.utils.plan_util import DynamicProgrammingProposer +from tzrec.utils.dynamicemb_util import has_dynamicemb +from tzrec.utils.plan_util import DynamicProgrammingProposer, _sparse_dp_proposor_numpy class PlanUtilTest(unittest.TestCase): @@ -136,5 +144,419 @@ def test_dp_proposer_with_prune(self) -> None: ) +class _FakeStorage: + def __init__(self, hbm, ddr): + self.hbm = hbm + self.ddr = ddr + + +class _FakeShard: + def __init__(self, hbm, ddr): + self.storage = _FakeStorage(hbm, ddr) + + +class _FakeShardingOption: + """Minimal ShardingOption stand-in: only the fields the DP proposer reads.""" + + def __init__(self, fqn, hbm, ddr, perf): + self.fqn = fqn + # Total = single shard for simplicity (single-rank assignment). + self.shards = [_FakeShard(hbm, ddr)] + self.total_storage = _FakeStorage(hbm, ddr) + self.total_perf = perf + + +def _make_topology(num_devices, hbm_per_device, ddr_per_device, local_world_size=None): + return SimpleNamespace( + devices=[ + SimpleNamespace( + storage=_FakeStorage(hbm=hbm_per_device, ddr=ddr_per_device) + ) + for _ in range(num_devices) + ], + local_world_size=local_world_size or num_devices, + ) + + +def _dense_dp_proposor_python_reference( + table_opts: List[Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]], + hbm_bins: int, + ddr_bins: int, +) -> List[List[int]]: + """Dense T x H x D Python DP -- the oracle for sparse-K NumPy equivalence. + + Same input shape as plan_util._sparse_dp_proposor_numpy: per table, a tuple + of ``(opt_hbm, opt_ddr, opt_perf, opt_global_id)`` in bin-units. + Algorithmically identical to the in-tree DP that the NumPy version + replaced (dense state + full backtrack), tightened with a 2-row dp ring + buffer so it stays tractable on small property-test problem sizes. + """ + table_count = len(table_opts) + if table_count == 0: + return [] + INF = float("inf") + empty_state = (INF, INF, INF) + prev_dp = [[empty_state] * ddr_bins for _ in range(hbm_bins)] + backtrack = [ + [[(-1, -1, -1)] * ddr_bins for _ in range(hbm_bins)] for _ in range(table_count) + ] + + # Seed table 0. + seed_hbm, seed_ddr, seed_perf, seed_opt_id = table_opts[0] + for opt_j in range(len(seed_perf)): + hbm = float(seed_hbm[opt_j]) + ddr = float(seed_ddr[opt_j]) + perf = float(seed_perf[opt_j]) + if hbm >= hbm_bins or ddr >= ddr_bins: + continue + hbm_i, ddr_i = int(hbm), int(ddr) + if prev_dp[hbm_i][ddr_i][0] > perf: + prev_dp[hbm_i][ddr_i] = (perf, hbm, ddr) + backtrack[0][hbm_i][ddr_i] = (int(seed_opt_id[opt_j]), -1, -1) + + # Transitions. + for table_i in range(1, table_count): + cur_table_hbm, cur_table_ddr, cur_table_perf, cur_table_opt_id = table_opts[ + table_i + ] + cur_dp = [[empty_state] * ddr_bins for _ in range(hbm_bins)] + for hbm_i in range(hbm_bins): + for ddr_i in range(ddr_bins): + prev_perf, prev_hbm, prev_ddr = prev_dp[hbm_i][ddr_i] + if prev_perf == INF: + continue + for opt_j in range(len(cur_table_perf)): + new_hbm = prev_hbm + float(cur_table_hbm[opt_j]) + new_ddr = prev_ddr + float(cur_table_ddr[opt_j]) + if new_hbm >= hbm_bins or new_ddr >= ddr_bins: + continue + new_hbm_i, new_ddr_i = int(new_hbm), int(new_ddr) + new_perf = prev_perf + float(cur_table_perf[opt_j]) + if cur_dp[new_hbm_i][new_ddr_i][0] > new_perf: + cur_dp[new_hbm_i][new_ddr_i] = (new_perf, new_hbm, new_ddr) + backtrack[table_i][new_hbm_i][new_ddr_i] = ( + int(cur_table_opt_id[opt_j]), + hbm_i, + ddr_i, + ) + prev_dp = cur_dp + + # Per-HBM-bin best, decreasing HBM order. + proposals: List[List[int]] = [] + last_back = backtrack[table_count - 1] + for hbm_i in range(hbm_bins - 1, -1, -1): + best_perf, best_ddr_i = INF, -1 + for ddr_i in range(ddr_bins): + if last_back[hbm_i][ddr_i][0] >= 0 and prev_dp[hbm_i][ddr_i][0] < best_perf: + best_perf, best_ddr_i = prev_dp[hbm_i][ddr_i][0], ddr_i + if best_ddr_i < 0: + continue + cur_opt_j, prev_hbm_i, prev_ddr_i = last_back[hbm_i][best_ddr_i] + proposal_indices = [-1] * table_count + proposal_indices[table_count - 1] = cur_opt_j + for table_i in range(table_count - 2, -1, -1): + proposal_indices[table_i], prev_hbm_i, prev_ddr_i = backtrack[table_i][ + prev_hbm_i + ][prev_ddr_i] + proposals.append(proposal_indices) + return proposals + + +def _make_random_table_opts( + seed: int, + table_count: int = 5, + option_count: int = 4, + hbm_bins: int = 8, + ddr_bins: int = 8, +) -> List[Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]]: + """Seeded random per-table option arrays for property testing. + + Per-option bin contributions stay in [0, axis/2) so combined plans can fit + several tables without saturating either axis. + """ + rng = random.Random(seed) + table_opts = [] + for _ in range(table_count): + hbm = np.asarray( + [rng.uniform(0, hbm_bins / 2) for _ in range(option_count)], + dtype=np.float32, + ) + ddr = np.asarray( + [rng.uniform(0, ddr_bins / 2) for _ in range(option_count)], + dtype=np.float32, + ) + perf = np.asarray( + [rng.uniform(1, 100) for _ in range(option_count)], dtype=np.float32 + ) + opt_id = np.arange(option_count, dtype=np.int32) + table_opts.append((hbm, ddr, perf, opt_id)) + return table_opts + + +class DynamicProgrammingProposerTest(unittest.TestCase): + """2D DP across HBM × DDR picks per-table mode under joint budgets.""" + + def _run(self, search_space, topology): + proposer = DynamicProgrammingProposer( + hbm_bins_per_device=20, ddr_bins_per_device=20 + ) + proposer.load(search_space) + # First propose returns the lowest-mem-per-table seed. + proposer.propose() + proposer.feedback(partitionable=True, storage_constraint=topology) + proposals = [] + proposal = proposer.propose() + while proposal: + proposals.append(proposal) + proposer.feedback(partitionable=True, storage_constraint=topology) + proposal = proposer.propose() + return proposals + + def test_caching_preferred_when_ddr_is_generous(self): + # Three options for one table: + # HYBRID @ x=1.0: hbm = T, ddr = 0, perf = high (HBM-only) + # HYBRID @ x=0.1: hbm = .1T, ddr = .9T, perf = high (slow) + # CACHING @ x=0.1: hbm = .1T, ddr = T, perf = low (fast — modeled hits) + opts = [ + _FakeShardingOption("table_a", hbm=1000, ddr=0, perf=50.0), + _FakeShardingOption("table_a", hbm=100, ddr=900, perf=80.0), + _FakeShardingOption("table_a", hbm=100, ddr=1000, perf=10.0), + ] + topology = _make_topology( + num_devices=2, hbm_per_device=2000, ddr_per_device=2000 + ) + proposals = self._run(opts, topology) + # Best plan must be the CACHING option (perf=10). + best = min(proposals, key=lambda p: sum(o.total_perf for o in p)) + self.assertEqual(best[0].total_perf, 10.0) + + def test_caching_rejected_when_ddr_is_tight(self): + # Host budget is only 950 — CACHING (ddr=1000) cannot fit; HYBRID can. + opts = [ + _FakeShardingOption("table_a", hbm=100, ddr=900, perf=80.0), + _FakeShardingOption("table_a", hbm=100, ddr=1000, perf=10.0), + ] + topology = _make_topology( + num_devices=1, hbm_per_device=2000, ddr_per_device=950 + ) + proposals = self._run(opts, topology) + # Every proposed plan must pick the HYBRID option (perf=80). + for p in proposals: + self.assertEqual(p[0].total_perf, 80.0) + + def test_high_factor_collapses_modes(self): + # At x=1.0 HYBRID == CACHING in HBM and CACHING.ddr = T = HYBRID.hbm. + # If we offer just the high-factor options, DP picks one of them. + opts = [ + _FakeShardingOption("table_a", hbm=1000, ddr=0, perf=50.0), # HYBRID x=1.0 + _FakeShardingOption( + "table_a", hbm=1000, ddr=1000, perf=50.0 + ), # CACHING x=1.0 + ] + topology = _make_topology( + num_devices=1, hbm_per_device=1100, ddr_per_device=2000 + ) + proposals = self._run(opts, topology) + # Either option is fine — they're tied. Just verify the proposer + # returned something feasible. + self.assertGreater(len(proposals), 0) + for p in proposals: + self.assertEqual(p[0].total_perf, 50.0) + + def test_two_tables_pick_mixed_modes_under_joint_budget(self): + # Two tables, each with HYBRID@1.0 (all-HBM, no DDR) and CACHING@0.1 + # (small HBM, full-T DDR). Topology HBM=2000 admits exactly one + # full HYBRID + one small CACHING shard (1500+100), and host DDR + # admits exactly one full-T CACHING backing (1500). Both-HYBRID is + # HBM-infeasible (3000>2000), both-CACHING is DDR-infeasible + # (3000>2000). Only the mixed plan fits. Exercises the + # cross-table DP transition at plan_util.py table_i==1. + opts = [ + _FakeShardingOption("table_a", hbm=1500, ddr=0, perf=50.0), # HYBRID@1.0 + _FakeShardingOption("table_a", hbm=100, ddr=1500, perf=40.0), # CACHING@0.1 + _FakeShardingOption("table_b", hbm=1500, ddr=0, perf=50.0), # HYBRID@1.0 + _FakeShardingOption("table_b", hbm=100, ddr=1500, perf=40.0), # CACHING@0.1 + ] + topology = _make_topology( + num_devices=1, hbm_per_device=2000, ddr_per_device=2000 + ) + proposals = self._run(opts, topology) + self.assertGreater(len(proposals), 0) + best = min(proposals, key=lambda p: sum(o.total_perf for o in p)) + styles = sorted( + "hybrid" if o.shards[0].storage.ddr == 0 else "caching" for o in best + ) + self.assertEqual(styles, ["caching", "hybrid"]) + + def test_per_machine_ddr_prune_on_multi_host_topology(self): + # 4 GPUs across 2 machines (local_world_size=2). Each machine has + # 1000 DDR; total = 2000. An option whose per-shard ddr is 1500 + # exceeds the 1000 per-machine cap and must be pruned, even + # though 1500 < ddr_total. The 900-ddr option fits. + topology = _make_topology( + num_devices=4, + local_world_size=2, + hbm_per_device=2000, + ddr_per_device=500, + ) + # 1500 > per-machine cap (1000) -> pruned, no proposal. + proposals_pruned = self._run( + [_FakeShardingOption("t", hbm=100, ddr=1500, perf=10.0)], topology + ) + self.assertEqual(proposals_pruned, []) + # 900 <= per-machine cap (1000) -> survives, proposal emitted. + proposals_fit = self._run( + [_FakeShardingOption("t", hbm=100, ddr=900, perf=10.0)], topology + ) + self.assertGreater(len(proposals_fit), 0) + + @parameterized.expand([(seed,) for seed in [0, 1, 7, 42, 1337]]) + def test_sparse_numpy_matches_dense_reference(self, seed): + # Property test: sparse-K NumPy DP must produce the same proposal set + # as the dense T x H x D Python reference oracle for any random input. + table_opts = _make_random_table_opts( + seed, table_count=5, option_count=4, hbm_bins=8, ddr_bins=8 + ) + actual_proposals = _sparse_dp_proposor_numpy(table_opts, hbm_bins=8, ddr_bins=8) + expected_proposals = _dense_dp_proposor_python_reference( + table_opts, hbm_bins=8, ddr_bins=8 + ) + self.assertEqual( + {tuple(p) for p in actual_proposals}, + {tuple(p) for p in expected_proposals}, + ) + + def test_empty_search_space_returns_empty_proposal(self): + proposer = DynamicProgrammingProposer() + proposer.load([]) + # Seed proposal is the per-table first option; with no tables the + # list is empty. + self.assertEqual(proposer.propose(), []) + # Feedback must not raise on an empty proposer (it short-circuits + # via `table_count == 0`). + topology = _make_topology( + num_devices=1, hbm_per_device=1000, ddr_per_device=1000 + ) + proposer.feedback(partitionable=True, storage_constraint=topology) + # After feedback, no proposals are available. + self.assertIsNone(proposer.propose()) + + +@unittest.skipUnless(has_dynamicemb, "dynamicemb is not installed; skipping.") +@unittest.skipUnless(torch.cuda.is_available(), "CUDA is required for dynamicemb.") +class PlanUtilDynamicEmbE2ETest(unittest.TestCase): + """End-to-end exercise of the dynamicemb planner integration.""" + + def _build_constraint(self, max_capacity=4096): + import dynamicemb + from dynamicemb.planner import DynamicEmbParameterConstraints + + opts = dynamicemb.DynamicEmbTableOptions( + max_capacity=max_capacity, + initializer_args=dynamicemb.DynamicEmbInitializerArgs( + mode=dynamicemb.DynamicEmbInitializerMode.UNIFORM, + lower=-0.01, + upper=0.01, + ), + eval_initializer_args=dynamicemb.DynamicEmbInitializerArgs( + mode=dynamicemb.DynamicEmbInitializerMode.CONSTANT, value=0.0 + ), + score_strategy=dynamicemb.DynamicEmbScoreStrategy.STEP, + ) + return DynamicEmbParameterConstraints( + use_dynamicemb=True, + sharding_types=[ShardingType.ROW_WISE.value], + compute_kernels=[EmbeddingComputeKernel.CUSTOMIZED_KERNEL.value], + dynamicemb_options=opts, + ) + + def _build_model(self): + table = EmbeddingBagConfig( + num_embeddings=4096, + embedding_dim=32, + name="table_de", + feature_names=["feat_de"], + ) + return TestSparseNN(tables=[table], sparse_device=torch.device("meta")) + + def test_enumerate_yields_both_modes_and_all_factors(self): + from tzrec.utils.plan_util import ( + EmbeddingEnumerator as _TzrecEmbeddingEnumerator, + ) + from tzrec.utils.plan_util import ( + get_default_sharders as _tzrec_get_default_sharders, + ) + + model = self._build_model() + topology = Topology(world_size=2, compute_device="cuda") + enumerator = _TzrecEmbeddingEnumerator( + topology=topology, + batch_size=128, + fqn_constraints={"sparse.ebc.table_de": self._build_constraint()}, + ) + search_space = enumerator.enumerate( + module=model, sharders=_tzrec_get_default_sharders() + ) + self.assertEqual(len(search_space), 20) + caching_modes = sorted( + { + so.dynamicemb_options.caching + for so in search_space + if getattr(so, "use_dynamicemb", False) + } + ) + self.assertEqual(caching_modes, [False, True]) + load_factors = sorted( + { + round(so.cache_load_factor, 4) + for so in search_space + if getattr(so, "use_dynamicemb", False) + } + ) + self.assertEqual(load_factors, [round((i + 1) / 10, 4) for i in range(10)]) + # Each option must carry a non-zero perf and storage estimate. + for so in search_space: + self.assertGreater(so.total_perf, 0) + self.assertGreaterEqual(so.total_storage.hbm, 0) + self.assertGreaterEqual(so.total_storage.ddr, 0) + + def test_dp_proposer_picks_feasible_dynamicemb_plan(self): + from tzrec.utils.plan_util import ( + EmbeddingEnumerator as _TzrecEmbeddingEnumerator, + ) + from tzrec.utils.plan_util import ( + get_default_sharders as _tzrec_get_default_sharders, + ) + + model = self._build_model() + topology = Topology(world_size=2, compute_device="cuda") + enumerator = _TzrecEmbeddingEnumerator( + topology=topology, + batch_size=128, + fqn_constraints={"sparse.ebc.table_de": self._build_constraint()}, + ) + search_space = enumerator.enumerate( + module=model, sharders=_tzrec_get_default_sharders() + ) + + proposer = DynamicProgrammingProposer() + proposer.load(search_space) + proposal = proposer.propose() + self.assertIsNotNone(proposal) + proposer.feedback(partitionable=True, storage_constraint=topology) + + # At least one further proposal should be generated by the 2D DP. + count = 0 + proposal = proposer.propose() + while proposal is not None and count < 5: + count += 1 + for so in proposal: + if getattr(so, "use_dynamicemb", False): + self.assertIn(so.dynamicemb_options.caching, (False, True)) + proposer.feedback(partitionable=True, storage_constraint=topology) + proposal = proposer.propose() + self.assertGreater(count, 0) + + if __name__ == "__main__": unittest.main() diff --git a/tzrec/version.py b/tzrec/version.py index 1c5d53f1..cc32ec00 100644 --- a/tzrec/version.py +++ b/tzrec/version.py @@ -9,4 +9,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "1.2.12" +__version__ = "1.2.13"