Skip to content

Commit c0f4594

Browse files
authored
[Refactor] refactor noisy linear (#3082)
1 parent 8a2ac04 commit c0f4594

3 files changed

Lines changed: 468 additions & 21 deletions

File tree

sota-implementations/redq/utils.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55
from __future__ import annotations
66

77
from collections.abc import Callable, Sequence
8-
98
from copy import copy
9+
from functools import partial
1010

1111
import torch
1212
from omegaconf import OmegaConf
@@ -417,7 +417,11 @@ def make_redq_model(
417417
if qvalue_net_kwargs is None:
418418
qvalue_net_kwargs = {}
419419

420-
linear_layer_class = torch.nn.Linear if not cfg.exploration.noisy else NoisyLinear
420+
linear_layer_class = (
421+
torch.nn.Linear
422+
if not cfg.exploration.noisy
423+
else partial(NoisyLinear, use_exploration_type=True)
424+
)
421425

422426
out_features_actor = (2 - gSDE) * action_spec.shape[-1]
423427
if cfg.env.from_pixels:

0 commit comments

Comments
 (0)