[Refactor] Dynamic value-estimator registry across all loss modules#3780
Open
vmoens wants to merge 1 commit into
Open
[Refactor] Dynamic value-estimator registry across all loss modules#3780vmoens wants to merge 1 commit into
vmoens wants to merge 1 commit into
Conversation
Replaces the hard-coded if/elif dispatch in every loss's
`make_value_estimator` with a registry that estimators self-register
into. Adding a new estimator now requires only the
`@register_value_estimator` decorator on the class — no edits in any
existing loss file.
What changed:
- `torchrl/objectives/utils.py`:
- new `_VALUE_ESTIMATOR_REGISTRY` dict;
- `register_value_estimator(value_type, default_kwargs=...)` decorator;
- `get_value_estimator_entry(value_type)` lookup that accepts both
`ValueEstimators` enum members and lowercase string aliases (e.g.
`"gae"`) — handy for hydra/yaml configs;
- `build_value_estimator(loss, value_type, **hp)` low-level builder;
- `dispatch_value_estimator(loss, value_type, *, supported,
tensor_keys, **hp)` — the high-level helper most losses now use.
It validates `value_type` against the loss's `supported` set,
builds the estimator via the registry, and applies `set_keys`;
- `default_value_kwargs()` is now a thin shim that reads from the
registry (back-compat preserved).
- `torchrl/objectives/value/advantages.py`: `ValueEstimatorBase` grows
a `for_loss(cls, loss_module, **hp)` classmethod that the registry
uses to wire each estimator against a loss. Default picks
`loss.critic_network` (falling back to `loss.value_network`), but
the loss can pass `value_network=<the module>` through
`dispatch_value_estimator` to override. `VTrace` overrides
`for_loss` to also wire the actor and handle the functional-actor
deep-copy. All five built-in estimators (TD0/TD1/TDLambda/GAE/VTrace)
carry a `@register_value_estimator` decorator with their default
kwargs.
- Every loss declares a `SUPPORTED_VALUE_ESTIMATORS` tuple and its
`make_value_estimator` body collapses to a single
`dispatch_value_estimator(...)` call:
- `ppo.py`, `a2c.py`, `reinforce.py`: all five estimators supported.
- `dqn.py` (DQN + QMixer), `iql.py`, `ddpg.py`, `td3.py`, `td3_bc.py`,
`sac.py`, `cql.py`, `redq.py`, `dreamer.py`, `dreamer_v3.py`,
`crossq.py`, `deprecated.py`, `multiagent/qmixer.py`: `{TD0, TD1,
TDLambda}` (QMixer only TD0).
The instance/class fast-path for power users who pass a
`ValueEstimatorBase` directly stays at the top of every body.
Tests: new `TestValueEstimatorRegistry` suite (8 tests) in
`test/objectives/test_loss_module.py` covers all built-ins being
registered, string-alias resolution, error paths for bad keys, custom
estimator registration, `dispatch_value_estimator`'s unsupported-set
check, the `value_network` override path, and back-compat between
`default_value_kwargs()` and the registry.
Net: ~700 lines deleted, ~600 added.
Co-authored-by: Cursor <cursoragent@cursor.com>
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/rl/3780
Note: Links to docs will display an error until the docs builds have been completed. ❗ 1 Active SEVsThere are 1 currently active SEVs. If your PR is affected, please view them below: ✅ No FailuresAs of commit 1e03f13 with merge base 258dfad ( This comment was automatically generated by Dr. CI and updates every 15 minutes. |
vmoens
added a commit
to theap06/rl
that referenced
this pull request
May 19, 2026
Adds first-class objectives for cooperative continuous-control MARL, plus the supporting infrastructure they need. ## What's new - ``torchrl.objectives.multiagent.MAPPOLoss`` -- centralised-critic, decentralised-actor PPO (Yu et al. 2022, https://arxiv.org/abs/2103.01955). Subclasses ``ClipPPOLoss``; defaults the value estimator to ``MultiAgentGAE``, defaults ``normalize_advantage_exclude_dims=(-2,)``, and optionally accepts a ``ValueNorm`` for the critic-stability trick from the paper (Table 13). - ``torchrl.objectives.multiagent.IPPOLoss`` -- independent-learner counterpart (de Witt et al. 2020, https://arxiv.org/abs/2011.09533). Each agent has its own local critic; no centralised state required. - ``torchrl.objectives.value.MultiAgentGAE`` -- ``GAE`` subclass that broadcasts team-shared ``reward`` / ``done`` / ``terminated`` (shape ``[*B, T, 1]``) across the agent dim before the vec-GAE recursion, so users don't have to manually replicate signals. Per-agent rewards pass through unchanged (competitive settings). New ``ValueEstimators.MAGAE`` enum entry. - ``torchrl.modules.ValueNorm`` / ``PopArtValueNorm`` / ``RunningValueNorm`` -- abstract ``ValueNorm`` (an ``nn.Module``) with two implementations: PopArt-style EMA (van Hasselt et al. 2019, https://arxiv.org/abs/1809.04474) and exact Welford running stats. Plugs into ``MAPPOLoss(value_norm=...)`` and normalises both the critic target and the prediction so the MSE / smooth-L1 distance stays on a fixed scale as reward scales drift. Composes correctly with the parent ``ClipPPOLoss`` features (``clip_value``, ``separate_losses``, ``log_explained_variance``) via the new ``_critic_loss_inputs`` hook. ## Design notes **Two classes, no centralized boolean flag.** The MAPPO / IPPO structural difference is small but the construction recipes differ in ways that matter (centralised critic vs. per-agent critic), so we expose them as separate named classes; the docstring on each spells out the recipe explicitly. **MAGAE dispatch in plain PPO / A2C / Reinforce.** Adding ``ValueEstimators.MAGAE`` to the enum would break every parent test that parametrises over ``list(ValueEstimators)`` unless every ``make_value_estimator`` knows the new value. We dispatch MAGAE through ``MultiAgentGAE`` in those losses (~5 lines per file). When the registry from pytorch#3780 lands, the explicit branches collapse to a single ``@register_value_estimator`` decorator. **ValueNorm placement.** Lives under ``torchrl/modules/`` rather than ``torchrl/objectives/utils/`` because it's a stateful learnable component that participates in ``.to(device)`` / ``state_dict()``. ## Verification - ``pytest test/objectives/test_mappo.py`` -- 25/25 passing. Includes regression tests for: * ``value_norm`` registered exactly once in ``state_dict()`` (no duplicate ``_value_norm_module.*`` keys); * ``state_dict()`` save / load round-trips the running stats; * ``clip_value`` + ``value_norm`` still produces ``value_clip_fraction`` in the output; * ``separate_losses=True`` + ``value_norm`` correctly detaches so critic-loss grads do not flow into actor params; * no spurious annotation warnings on instantiation. - ``pytest test/objectives/`` -- 7280 passing, 2319 skipped (no regressions). - ``examples/multiagent/mappo_vmas.py --algo mappo --frames 200_000`` provides a minimal end-to-end VMAS Navigation smoke recipe. Co-authored-by: Cursor <cursoragent@cursor.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Context
Today every loss in
torchrl.objectivesships its ownmake_value_estimatorbody with a hand-codedif/elifchain overValueEstimators.{TD0,TD1,GAE,TDLambda,VTrace}. The chain knows the class names, the default hyperparameters, and any per-estimator construction quirks (e.g. V-Trace needs the actor). Adding a new estimator therefore means touching ~15 loss files plus their tests.This PR replaces that with a registry: estimators self-register via a decorator at class-definition time, and every loss dispatches through a single helper.
This is the registry half of #3748 — pulled out into its own PR per review feedback so that the MAPPO/IPPO/MultiAgentGAE/ValueNorm work in #3748 stays scoped to its features, and the registry refactor can be reviewed independently and extended to all losses (the original PR only converted PPO / A2C / Reinforce, leaving the headline benefit unfulfilled for the other ~12 loss classes).
What's new
torchrl.objectives.utils_VALUE_ESTIMATOR_REGISTRY: dict[ValueEstimators, _ValueEstimatorRegistryEntry]@register_value_estimator(value_type, *, default_kwargs=...)— class decorator that estimators apply at definition time.get_value_estimator_entry(value_type)— accepts bothValueEstimatorsenum members and lowercase string aliases (`"gae"`, `"vtrace"`, ...). Handy for hydra/yaml configs.build_value_estimator(loss, value_type, **hp)— low-level: looks up the class via the registry, mergesentry.default_kwargswith the caller'shp, then callscls.for_loss(loss, **merged).dispatch_value_estimator(loss, value_type, *, supported, tensor_keys, **hp)— high-level helper most losses use. It:value_typeagainst the loss'ssupportedset;gammafromloss.gammaif present;loss._value_estimator/loss.value_type;set_keys(**tensor_keys).default_value_kwargs()is now a 1-line shim over the registry (back-compat preserved).torchrl.objectives.value.advantagesValueEstimatorBasegrows afor_loss(cls, loss_module, **hp)classmethod. Default behaviour: pickloss.critic_network, fall back toloss.value_network. If the caller (i.e. the loss) passesvalue_network=<module>inhp, that wins.VTraceoverridesfor_lossto also pluck the actor and handle the functional-actor deep-copy.@register_value_estimator(...)decorator with their default kwargs.Per-loss conversion
Every loss now declares a
SUPPORTED_VALUE_ESTIMATORStuple, and itsmake_value_estimatorbody collapses to a singledispatch_value_estimator(...)call:PPOLoss,A2CLoss,ReinforceLossDQNLoss,IQLLoss,DiscreteIQLLoss,DDPGLoss,TD3Loss,TD3BCLoss,SACLoss,DiscreteSACLoss,CQLLoss,DiscreteCQLLoss,REDQLoss,REDQLoss_deprecated,DreamerActorLoss,DreamerV3ActorLoss,CrossQLossQMixerLossThe instance/class fast-path for power users who pass a
ValueEstimatorBasedirectly stays at the top of every body.Why a partial conversion would have been bad
The version in #3748 only converted PPO/A2C/Reinforce. That left the if/elif chains intact in 12 other losses, and the PR still had to hand-edit 8 test files to add new estimators to the "raises NotImplementedError" lists. The registry's whole pitch ("no edits in existing losses when adding a new estimator") only holds if every loss participates. This PR finishes the conversion so the next estimator addition — including
MAGAEin #3748 — really does require zero edits to existing losses.Verification
TestValueEstimatorRegistrysuite (8 tests) intest/objectives/test_loss_module.py:\"gae\"↔ValueEstimators.GAE)dispatch_value_estimatorrejects estimators outside the supported setvalue_network=...override pathdefault_value_kwargs()matches the registrypytest test/objectives/test_ppo.py test/objectives/test_dqn.py test/objectives/test_ddpg.py test/objectives/test_sac.py test/objectives/test_iql.py test/objectives/test_cql.py test/objectives/test_dreamer.py— 4521 passed.Stats
torchrl/objectives/utils.py(one file, ~120 lines).Out of scope (for #3748 to add on top)
MAGAE/MultiAgentGAEregistration — lives in [Feature] MAPPOLoss + IPPOLoss + MultiAgentGAE + ValueNorm #3748 with the rest of the MAPPO feature.Made with Cursor