Skip to content

Commit deff70b

Browse files
committed
proper config management to handle two disticnt class of policies
1 parent 4592d48 commit deff70b

1 file changed

Lines changed: 13 additions & 1 deletion

File tree

pufferlib/pufferl.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4232,7 +4232,13 @@ def apply_common_mining_worker_kwargs(worker_kwargs):
42324232
if policy is not None and policy.__class__.__name__ != "TargetDrive":
42334233
raise pufferlib.APIUsageError("target-actor traffic mining requires a TargetDrive policy")
42344234

4235-
target_args = _prepare_target_policy_args(args, target_policy_path)
4235+
traffic_args = args
4236+
if traffic_target_policy_path is not None and args.get("traffic_policy_config") is not None:
4237+
traffic_args = copy.deepcopy(args)
4238+
traffic_args["target_policy_config"] = args.get("traffic_policy_config")
4239+
traffic_args["train"]["target_policy_config"] = args.get("traffic_policy_config")
4240+
4241+
target_args = _prepare_target_policy_args(traffic_args, target_policy_path)
42364242
target_env = _make_target_policy_env_view(vecenv.driver_env, target_args)
42374243
policy = policy or load_policy(target_args, vecenv, env_name, policy_env=target_env)
42384244
policy._puffer_policy_env = target_env
@@ -4809,6 +4815,12 @@ def load_config(env_name, config_dir=None):
48094815
default=None,
48104816
help="Optional target policy config.yaml used to reconstruct frozen target architecture/observation layout",
48114817
)
4818+
parser.add_argument(
4819+
"--traffic-policy-config",
4820+
type=str,
4821+
default=None,
4822+
help="Optional config.yaml used to reconstruct target-actor traffic policy architecture/observation layout",
4823+
)
48124824
parser.add_argument(
48134825
"--load-id", type=str, default=None, help="Kickstart/eval from from a finished Wandb/Neptune run"
48144826
)

0 commit comments

Comments
 (0)