-
Notifications
You must be signed in to change notification settings - Fork 72
[feat] planner: pick dynamicemb HYBRID vs CACHING from topology budgets #508
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
ffc0325
ead7df2
c88c4f4
6ce464b
7187c2a
c3bff19
38bd7ad
70b0fcb
c28e2b2
5280156
625e1b7
3ff8be6
984d0df
08563bb
0c0d3d7
3b7c1bf
9cc6ed2
d1727ba
fdaac44
1513444
12a7c15
20ea5c3
72d121c
0467638
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -9,9 +9,10 @@ | |
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| import dataclasses | ||
| import math | ||
| import os | ||
| from typing import List, Optional, Tuple, Type, cast | ||
| from typing import Any, List, Optional, Tuple, Type, cast | ||
|
|
||
| import torch | ||
| from torch import nn | ||
|
|
@@ -21,7 +22,10 @@ | |
| planners, | ||
| shard_estimators, | ||
| ) | ||
| from torchrec.distributed.planner.estimator.types import HardwarePerfConfig | ||
| from torchrec.distributed.planner.estimator.types import ( | ||
| HardwarePerfConfig, | ||
| ShardPerfContext, | ||
| ) | ||
| from torchrec.distributed.planner.types import ( | ||
| ParameterConstraints, | ||
| ShardingOption, | ||
|
|
@@ -43,6 +47,89 @@ | |
| from torchrec.modules.embedding_configs import BaseEmbeddingConfig | ||
|
|
||
| from tzrec.protos import feature_pb2 | ||
| from tzrec.utils.logging_util import logger | ||
|
|
||
| _DYNAMICEMB_CACHING_X_EFF_BASE = 0.28 | ||
| _DYNAMICEMB_HYBRID_X_EFF_BASE = 0.11 | ||
| _DYNAMICEMB_X_EFF_TIEBREAK = 0.01 | ||
|
|
||
|
|
||
| def _dynamicemb_effective_cache_ratio( | ||
| cache_load_factor: Optional[float], | ||
| caching: bool, | ||
| stats: Optional[Any] = None, | ||
| ) -> float: | ||
| """Effective HBM-hit ratio for the dynamicemb perf model. | ||
|
|
||
| Returns the value passed to torchrec's perf bandwidth formula | ||
| ``bw = x_eff*hbm + (1-x_eff)*hbm_to_ddr_bw``. Larger value = faster path. | ||
|
|
||
| The ratio is derived from an on-device perf sweep, not a heuristic. | ||
| Empirical pattern (alpha=1.05 pow-law access on A10): | ||
|
|
||
| * ``x == 1.0``: the runtime *switches kernels* — when | ||
| ``total_value_memory <= local_hbm_for_values`` the dual-tier | ||
| ``HybridStorage`` / ``DynamicEmbCache`` paths are dropped in favor | ||
| of the HBM-only ``DynamicEmbStorage`` kernel | ||
| (``batched_dynamicemb_tables.py:502-604``). The ~8x jump in ``x_eff`` | ||
| between ``x=0.9`` and ``x=1.0`` is intentional and matches measured | ||
| latency, not a smoothing artifact. (A future refactor could lift | ||
| this to a discrete ``mode={HBM_ONLY, HYBRID, CACHING}`` axis on the | ||
| enumerator side rather than packing the discontinuity into ``x``.) | ||
| * ``caching=True``, ``x < 1.0``: 3.3x slower than HBM-only -> base 0.28. | ||
| * ``caching=False``, ``x < 1.0``: 6.8x slower than HBM-only -> base 0.11. | ||
|
|
||
| Within each ``x < 1.0`` block the perf is roughly flat in ratio, but we | ||
| add a tiny monotonic perturbation so the DP can break ties. | ||
|
|
||
| If ``stats`` is provided, ``1 - stats.expected_miss_rate(x)`` overrides | ||
| the heuristic verbatim (clamped to ``[0, 1]``); the caller opts in to | ||
| their own measurement. | ||
| """ | ||
| x = float(cache_load_factor) if cache_load_factor is not None else 0.0 | ||
| x = max(0.0, min(1.0, x)) | ||
| if stats is not None: | ||
| miss_rate = float(stats.expected_miss_rate(x)) | ||
| return max(0.0, min(1.0, 1.0 - miss_rate)) | ||
| if x >= 1.0: | ||
| return 1.0 | ||
| base = _DYNAMICEMB_CACHING_X_EFF_BASE if caching else _DYNAMICEMB_HYBRID_X_EFF_BASE | ||
| return base + _DYNAMICEMB_X_EFF_TIEBREAK * x | ||
|
|
||
|
|
||
| def _log_dynamicemb_table_plan( | ||
| *, | ||
| fqn: str, | ||
| cache_load_factor: float, | ||
| caching: bool, | ||
| hbm_bytes: int, | ||
| ddr_bytes: int, | ||
| ) -> None: | ||
| """Per-table mode log on rank 0. | ||
|
|
||
| cache_load_factor=1.0 forces the runtime into HBM_ONLY (host tier | ||
| dropped) regardless of ``caching`` -- mirror that override here so | ||
| the log matches the runtime, not the planner's recorded | ||
| (caching, factor). Use exact equality so a stray >1.0 fails loud | ||
| downstream instead of being silently relabelled HBM_ONLY. | ||
| """ | ||
| if int(os.environ.get("RANK", 0)) != 0: | ||
| return | ||
| if cache_load_factor == 1.0: | ||
| mode = "HBM_ONLY" | ||
| dram_bytes = 0 # runtime drops the host tier | ||
| else: | ||
| mode = "CACHING" if caching else "HYBRID" | ||
| dram_bytes = ddr_bytes | ||
| hbm_gib = hbm_bytes / (1 << 30) | ||
| dram_gib = dram_bytes / (1 << 30) | ||
| logger.info( | ||
| f"[dynamicemb plan] {fqn}: mode={mode} " | ||
| f"cache_load_factor={cache_load_factor:.2f} " | ||
| f"local_hbm={hbm_gib:.3f}GiB " | ||
| f"local_dram={dram_gib:.3f}GiB" | ||
| ) | ||
|
|
||
|
|
||
| has_dynamicemb = False | ||
| try: | ||
|
|
@@ -258,26 +345,44 @@ def _calculate_dynamicemb_table_storage_specific_size( | |
| is_hbm: bool = True, | ||
| only_values: bool = False, | ||
| bucket_capacity: int = 128, | ||
| caching: bool = False, | ||
| ) -> int: | ||
| """Calculate dynamic embedding table storage. | ||
|
|
||
| total_value_memory = max_capacity x aligned16(embedding+optimizer states) | ||
| num_buckets = max_capacity/bucket_capacity | ||
| hbm_budget = min(global_hbm_for_values//world_size, total_value_memory) + | ||
| max_capacity x (key<8byte> + score<8byte> + digest<1byte>) + | ||
| num_buckets x (bucket_size<4byte>) | ||
| ddr_budget = max(total_value_memory - global_hbm_for_values//world_size, 0) | ||
| """Per-shard storage size for a dynamicemb table -- HBM or DDR (bytes). | ||
|
|
||
| Byte budget (single shard, rows x dim): | ||
|
|
||
| value_bytes_per_row = round_up16(dim * (1 + opt_mult) * element) | ||
| total_value_memory = align(rows) * value_bytes_per_row | ||
| num_buckets = align(rows) / bucket_capacity | ||
|
|
||
| hbm_budget = cache_ratio * total_value_memory # values | ||
| + align(rows) * (key<8B> + score<8B> + digest<1B>) # per-row | ||
| + num_buckets * bucket_header<4B> # per-bucket | ||
|
|
||
| ddr_budget = HYBRID (caching=False): (1 - cache_ratio) * total_value_memory | ||
| CACHING (caching=True): total_value_memory # full backing | ||
|
|
||
| HYBRID hash-partitions values across HBM and host; ``cache_ratio`` is | ||
| HBM's value share. CACHING keeps the full backing store on host and | ||
| uses HBM as a hot-row cache of size | ||
| ``cache_ratio * total_value_memory``. Hash-table metadata | ||
| (key + score + digest + bucket header) is accounted on HBM only -- | ||
| matches the existing tzrec convention. | ||
| """ | ||
| if cache_ratio is None: | ||
| cache_ratio = 1.0 | ||
| if is_hbm: | ||
| value_ratio = cache_ratio | ||
| else: | ||
| value_ratio = 1.0 if caching else (1.0 - cache_ratio) | ||
| return math.ceil( | ||
| align_to_table_size(size[0]) | ||
| * ( | ||
| _round_up( | ||
| math.ceil(size[1] * (1 + optimizer_multipler) * element_size), | ||
| 16, | ||
| ) | ||
| * (cache_ratio if is_hbm else 1 - cache_ratio) | ||
| * value_ratio | ||
| + (8 + 8 + 1 + 4 / bucket_capacity) * (is_hbm and not only_values) | ||
| ) | ||
| ) | ||
|
|
@@ -363,6 +468,13 @@ def _to_sharding_plan( | |
| dist_type="roundrobin", | ||
| dynamicemb_options=dynamicemb_options, | ||
| ) | ||
| _log_dynamicemb_table_plan( | ||
| fqn=f"{sharding_option.path}.{sharding_option.name}", | ||
| cache_load_factor=float(sharding_option.cache_load_factor), | ||
| caching=bool(dynamicemb_options.caching), | ||
| hbm_bytes=int(shards[0].storage.hbm), | ||
| ddr_bytes=int(shards[0].storage.ddr), | ||
| ) | ||
| else: | ||
| module_plan[sharding_option.name] = ParameterSharding( | ||
| sharding_spec=sharding_spec, | ||
|
|
@@ -413,13 +525,74 @@ def _customized_kernel_aware_get_device_bw( | |
| # pyre-ignore [9] | ||
| HardwarePerfConfig.get_device_bw = _customized_kernel_aware_get_device_bw | ||
|
|
||
| _orig_build_shard_perf_contexts = ( | ||
| ShardPerfContext.build_shard_perf_contexts.__func__ | ||
| ) | ||
|
|
||
| def _dynamicemb_aware_build_shard_perf_contexts( | ||
| cls, # pyre-ignore [2] | ||
| config, # pyre-ignore [2] | ||
| shard_sizes, # pyre-ignore [2] | ||
| sharding_option, # pyre-ignore [2] | ||
| topology, # pyre-ignore [2] | ||
| constraints, # pyre-ignore [2] | ||
| sharder, # pyre-ignore [2] | ||
|
Comment on lines
+532
to
+539
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The wrapper hard-codes the first six positional parameters ( Worth capturing |
||
| *args, # pyre-ignore [2] | ||
| **kwargs, # pyre-ignore [2] | ||
| ): | ||
| """Inject the empirical x_eff into the perf estimator for both modes. | ||
|
|
||
| Temporarily replace ``sharding_option.cache_params`` with a clone | ||
| whose ``load_factor`` is the empirically-fitted x_eff for the | ||
| (mode, cache_load_factor) combination. Restored before returning so | ||
| the (separately invoked) storage estimator still sees the un-boosted | ||
| ratio. | ||
| """ | ||
| dynamicemb_options = getattr(sharding_option, "dynamicemb_options", None) | ||
| original_cache_params = sharding_option.cache_params | ||
| if dynamicemb_options is not None: | ||
| caching = bool(getattr(dynamicemb_options, "caching", False)) | ||
| stats = original_cache_params.stats if original_cache_params else None | ||
| x_eff = _dynamicemb_effective_cache_ratio( | ||
| sharding_option.cache_load_factor, caching=caching, stats=stats | ||
| ) | ||
| sharding_option.cache_params = ( | ||
| dataclasses.replace(original_cache_params, load_factor=x_eff) | ||
| if original_cache_params is not None | ||
| else CacheParams(load_factor=x_eff) | ||
| ) | ||
| # try/finally so an estimator exception cannot leak the boosted | ||
| # cache_params clone into the storage estimator's view of the | ||
| # same ShardingOption. | ||
| try: | ||
| result = _orig_build_shard_perf_contexts( | ||
| cls, | ||
| config, | ||
| shard_sizes, | ||
| sharding_option, | ||
| topology, | ||
| constraints, | ||
| sharder, | ||
| *args, | ||
| **kwargs, | ||
| ) | ||
| finally: | ||
| sharding_option.cache_params = original_cache_params | ||
| return result | ||
|
|
||
| # pyre-ignore [9] | ||
| ShardPerfContext.build_shard_perf_contexts = classmethod( | ||
| _dynamicemb_aware_build_shard_perf_contexts | ||
| ) | ||
|
|
||
| def _calculate_dynamicemb_storage_specific_sizes( | ||
| tensor: torch.Tensor, | ||
| shard_sizes: List[List[int]], | ||
| optimizer_class: Optional[Type[torch.optim.Optimizer]] = None, | ||
| cache_ratio: float = 1.0, | ||
| is_inference: bool = False, | ||
| bucket_capacity: int = 128, | ||
| caching: bool = False, | ||
| ) -> Tuple[List[int], List[int]]: | ||
| """Calculate storage for dynamicemb.""" | ||
| optimizer_multipler = 0.0 | ||
|
|
@@ -437,6 +610,7 @@ def _calculate_dynamicemb_storage_specific_sizes( | |
| cache_ratio, | ||
| is_hbm=True, | ||
| bucket_capacity=bucket_capacity, | ||
| caching=caching, | ||
| ) | ||
| for size in shard_sizes | ||
| ] | ||
|
|
@@ -449,6 +623,7 @@ def _calculate_dynamicemb_storage_specific_sizes( | |
| cache_ratio, | ||
| is_hbm=False, | ||
| bucket_capacity=bucket_capacity, | ||
| caching=caching, | ||
| ) | ||
| for size in shard_sizes | ||
| ] | ||
|
|
@@ -496,7 +671,10 @@ def dynamicemb_calculate_shard_storages( | |
| factors. | ||
| num_poolings (List[float]): average number of poolings per sample | ||
| (typically 1.0). | ||
| caching_ratio (float): ratio of HBM to DDR memory for UVM caching. | ||
| caching_ratio (float): cache_load_factor for the dynamicemb table. | ||
| In HYBRID mode HBM holds this fraction of values and host | ||
| holds the remainder; in CACHING mode HBM is a hot-row cache | ||
| of this fraction and host holds the full backing store. | ||
| is_pooled (bool): True if embedding output is pooled (ie. `EmbeddingBag`), | ||
| False if unpooled/sequential (ie. `Embedding`). | ||
| input_data_type_size (int): number of bytes of input data type. | ||
|
|
@@ -535,6 +713,7 @@ def dynamicemb_calculate_shard_storages( | |
| cache_ratio=caching_ratio if caching_ratio else 1.0, | ||
| is_inference=is_inference, | ||
| bucket_capacity=dynamicemb_options.bucket_capacity, | ||
| caching=bool(getattr(dynamicemb_options, "caching", False)), | ||
| ) | ||
| ) | ||
| counter_hbm_specific_size = 0 | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sharp discontinuity at
x=1.0. HYBRID@x=0.99 →0.1199; HYBRID@x=1.00 →1.0— an ~8× jump for a 1% ratio change. The DP will reliably prefer x=1.0 over x=0.99 by a huge perf margin, then prefer CACHING@x=0.5 (0.285) over HYBRID@x=0.9 (0.119) even on workloads where HYBRID@0.9 is plainly faster in reality. If the empirical sweep really shows a step at x=1.0 because the runtime drops the host tier, please call that out as "x=1.0 = HBM-only kernel, not the same algorithm as x=0.99" — and consider whether the enumerator should emit a discretemode={HBM_ONLY, HYBRID, CACHING}axis rather than packing the discontinuity into the samexknob.