We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 8a2ac04 commit c0f4594Copy full SHA for c0f4594
3 files changed
sota-implementations/redq/utils.py
@@ -5,8 +5,8 @@
5
from __future__ import annotations
6
7
from collections.abc import Callable, Sequence
8
-
9
from copy import copy
+from functools import partial
10
11
import torch
12
from omegaconf import OmegaConf
@@ -417,7 +417,11 @@ def make_redq_model(
417
if qvalue_net_kwargs is None:
418
qvalue_net_kwargs = {}
419
420
- linear_layer_class = torch.nn.Linear if not cfg.exploration.noisy else NoisyLinear
+ linear_layer_class = (
421
+ torch.nn.Linear
422
+ if not cfg.exploration.noisy
423
+ else partial(NoisyLinear, use_exploration_type=True)
424
+ )
425
426
out_features_actor = (2 - gSDE) * action_spec.shape[-1]
427
if cfg.env.from_pixels:
0 commit comments