From 856b7d2ef289f1ec8d1c6d12cb9f3e58bbd5b2ce Mon Sep 17 00:00:00 2001 From: hscspring Date: Wed, 1 Feb 2023 11:00:28 +0800 Subject: [PATCH 1/2] fix: OnPolicyAlgorithm doesnot have the parameter: create_eval_env --- rl4lms/algorithms/nlpo/nlpo.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rl4lms/algorithms/nlpo/nlpo.py b/rl4lms/algorithms/nlpo/nlpo.py index acdc6d70..cf5895b5 100644 --- a/rl4lms/algorithms/nlpo/nlpo.py +++ b/rl4lms/algorithms/nlpo/nlpo.py @@ -114,7 +114,7 @@ def __init__( use_sde=False, sde_sample_freq=-1, tensorboard_log=tensorboard_log, - create_eval_env=create_eval_env, + # create_eval_env=create_eval_env, policy_kwargs=policy_kwargs, verbose=verbose, seed=seed, From 58b06328c73182d522a4b4cdf49dc88e8b4efb5b Mon Sep 17 00:00:00 2001 From: hscspring Date: Wed, 1 Feb 2023 14:47:45 +0800 Subject: [PATCH 2/2] fix: It's better to assign dtype in DictSpace of TextGenEnv explicitly --- rl4lms/envs/text_generation/env.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/rl4lms/envs/text_generation/env.py b/rl4lms/envs/text_generation/env.py index bf2a7981..1221be4f 100644 --- a/rl4lms/envs/text_generation/env.py +++ b/rl4lms/envs/text_generation/env.py @@ -1,6 +1,7 @@ from cmath import inf from typing import Dict, Tuple, Optional, List +import numpy as np import torch from gym import Env, spaces from gym.spaces.dict import Dict as DictSpace @@ -57,24 +58,24 @@ def __init__( # we have to provide fixed sized inputs (padded) because sb3 support for DictObsersevation is limited # while creating rollout buffers, observations are concatenated for each key "prompt_or_input_encoded_pt": spaces.Box( - low=0, high=self._vocab_size, shape=(self._max_text_length,) + low=0, high=self._vocab_size, shape=(self._max_text_length,), dtype=np.int64 ), "prompt_or_input_attention_mask_pt": spaces.Box( - low=0, high=1, shape=(self._max_text_length,) + low=0, high=1, shape=(self._max_text_length,), dtype=np.int64 ), "context_encoded_pt": spaces.Box( - low=0, high=self._vocab_size, shape=(self.max_steps,) + low=0, high=self._vocab_size, shape=(self.max_steps,), dtype=np.int64 ), "context_attention_mask_pt": spaces.Box( - low=0, high=1, shape=(self.max_steps,) + low=0, high=1, shape=(self.max_steps,), dtype=np.int64 ), "input_encoded_pt": spaces.Box( low=0, high=self._vocab_size, - shape=(self._max_text_length + self.max_steps,), + shape=(self._max_text_length + self.max_steps,), dtype=np.int64 ), "input_attention_mask_pt": spaces.Box( - low=0, high=1, shape=(self._max_text_length + self.max_steps,) + low=0, high=1, shape=(self._max_text_length + self.max_steps,), dtype=np.int64 ), } )