diff --git a/rl4lms/algorithms/nlpo/nlpo.py b/rl4lms/algorithms/nlpo/nlpo.py index acdc6d70..cf5895b5 100644 --- a/rl4lms/algorithms/nlpo/nlpo.py +++ b/rl4lms/algorithms/nlpo/nlpo.py @@ -114,7 +114,7 @@ def __init__( use_sde=False, sde_sample_freq=-1, tensorboard_log=tensorboard_log, - create_eval_env=create_eval_env, + # create_eval_env=create_eval_env, policy_kwargs=policy_kwargs, verbose=verbose, seed=seed, diff --git a/rl4lms/envs/text_generation/env.py b/rl4lms/envs/text_generation/env.py index bf2a7981..1221be4f 100644 --- a/rl4lms/envs/text_generation/env.py +++ b/rl4lms/envs/text_generation/env.py @@ -1,6 +1,7 @@ from cmath import inf from typing import Dict, Tuple, Optional, List +import numpy as np import torch from gym import Env, spaces from gym.spaces.dict import Dict as DictSpace @@ -57,24 +58,24 @@ def __init__( # we have to provide fixed sized inputs (padded) because sb3 support for DictObsersevation is limited # while creating rollout buffers, observations are concatenated for each key "prompt_or_input_encoded_pt": spaces.Box( - low=0, high=self._vocab_size, shape=(self._max_text_length,) + low=0, high=self._vocab_size, shape=(self._max_text_length,), dtype=np.int64 ), "prompt_or_input_attention_mask_pt": spaces.Box( - low=0, high=1, shape=(self._max_text_length,) + low=0, high=1, shape=(self._max_text_length,), dtype=np.int64 ), "context_encoded_pt": spaces.Box( - low=0, high=self._vocab_size, shape=(self.max_steps,) + low=0, high=self._vocab_size, shape=(self.max_steps,), dtype=np.int64 ), "context_attention_mask_pt": spaces.Box( - low=0, high=1, shape=(self.max_steps,) + low=0, high=1, shape=(self.max_steps,), dtype=np.int64 ), "input_encoded_pt": spaces.Box( low=0, high=self._vocab_size, - shape=(self._max_text_length + self.max_steps,), + shape=(self._max_text_length + self.max_steps,), dtype=np.int64 ), "input_attention_mask_pt": spaces.Box( - low=0, high=1, shape=(self._max_text_length + self.max_steps,) + low=0, high=1, shape=(self._max_text_length + self.max_steps,), dtype=np.int64 ), } )