Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -59,4 +59,8 @@ sysinfo.txt
*.apk
*.unitypackage

UnitySDK.log
UnitySDK.log
.venv/
PR.md
PR_Arguments.md
PR_Opening.md
100 changes: 94 additions & 6 deletions brax/io/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,104 @@

import pickle
from typing import Any
import warnings

from etils import epath
import jax
import msgpack
import numpy as np


def load_params(path: str) -> Any:
with epath.Path(path).open('rb') as fin:
buf = fin.read()
return pickle.loads(buf)
class SecurityWarning(UserWarning):
"""Warning category for insecure model loading."""

pass


def _encode_pytree(obj: Any) -> Any:
"""Recursively converts a Pytree into msgpack-compatible types."""
if isinstance(obj, (jax.Array, np.ndarray)):
# Standard metadata-preserving array format
return {
'__type__': 'array',
'data': obj.tobytes(),
'shape': obj.shape,
'dtype': str(obj.dtype),
}
# Handle flax.struct.dataclass and NamedTuples (like RunningStatisticsState)
if hasattr(obj, '__dict__') and hasattr(obj, '_asdict'):
return {
'__type__': obj.__class__.__name__,
'data': {k: _encode_pytree(v) for k, v in obj._asdict().items()},
}
# Handle nested containers
if isinstance(obj, dict):
return {k: _encode_pytree(v) for k, v in obj.items()}
if isinstance(obj, (list, tuple)):
return {
'__type__': obj.__class__.__name__,
'data': [_encode_pytree(x) for x in obj],
}
return obj


def _decode_pytree(obj: Any) -> Any:
"""Reconstructs Pytree types from serialized dictionaries."""
if isinstance(obj, dict):
type_name = obj.get('__type__')
# Reconstruct Arrays
if type_name == 'array':
return jax.numpy.frombuffer(obj['data'], dtype=obj['dtype']).reshape(
obj['shape']
)

# Reconstruct specialized Brax/Flax types
data = obj.get('data')
if type_name == 'RunningStatisticsState':
from brax.training.acme import running_statistics

return running_statistics.RunningStatisticsState(**_decode_pytree(data))
if type_name == 'UInt64':
from brax.training import types

return types.UInt64(**_decode_pytree(data))

# Reconstruct containers
if type_name == 'tuple':
return tuple(_decode_pytree(x) for x in data)
if type_name == 'list':
return [_decode_pytree(x) for x in data]

# Generic nested dicts
return {k: _decode_pytree(v) for k, v in (data if data else obj).items()}

return obj


def save_params(path: str, params: Any):
"""Saves parameters in flax format."""
"""Saves parameters safely using msgpack."""
encoded = _encode_pytree(params)
with epath.Path(path).open('wb') as fout:
fout.write(pickle.dumps(params))
fout.write(msgpack.packb(encoded))


def load_params(path: str, allow_pickle: bool = False) -> Any:
"""Loads parameters safely, with a security-gated legacy path."""
with epath.Path(path).open('rb') as fin:
buf = fin.read()

if buf.startswith(b'\x80'): # Pickle Protocol 2+ Header
if not allow_pickle:
raise RuntimeError(
'SECURITY ERROR: Insecure pickle file detected. For security reasons,'
' loading is blocked. Use allow_pickle=True if you trust the source.'
)

warnings.warn(
'SECURITY WARNING: Loading legacy pickle files is insecure and '
'deprecated. Please migrate your models to the new secure format.',
category=SecurityWarning,
)
return pickle.loads(buf)

return _decode_pytree(msgpack.unpackb(buf))
96 changes: 96 additions & 0 deletions brax/io/model_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
# Copyright 2026 The Brax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for parameter saving/loading."""

import os
import pickle
import tempfile

from absl.testing import absltest
import jax
import jax.numpy as jnp

from brax.io import model as brax_model


class ModelTest(absltest.TestCase):

def test_save_load_params(self):
"""Verifies that Msgpack serialization preserves Pytree data integrity."""
params = {
'policy': {
'w': jnp.ones((4, 8)),
'b': jnp.zeros((8,)),
},
'stats': (jnp.array([1.0, 2.0]), jnp.array([3.0, 4.0])),
'list': [jnp.array(1), jnp.array(2)],
}

with tempfile.TemporaryDirectory() as tmpdir:
path = os.path.join(tmpdir, 'params.msgpack')
brax_model.save_params(path, params)
loaded_params = brax_model.load_params(path)

# Check structure and values
import numpy as np

jax.tree_util.tree_map(np.testing.assert_allclose, params, loaded_params)

def test_pickle_security_block(self):
"""Verifies that legacy pickle files are blocked by default."""
params = {'test': 123}
with tempfile.TemporaryDirectory() as tmpdir:
path = os.path.join(tmpdir, 'params.pkl')
with open(path, 'wb') as f:
f.write(pickle.dumps(params))

with self.assertRaisesRegex(
RuntimeError, 'SECURITY ERROR: Insecure pickle file'
):
brax_model.load_params(path)

def test_pickle_allow_explicit(self):
"""Verifies that legacy files can still be loaded with explicit flag."""
params = {'test': 456}
with tempfile.TemporaryDirectory() as tmpdir:
path = os.path.join(tmpdir, 'params.pkl')
with open(path, 'wb') as f:
f.write(pickle.dumps(params))

with self.assertWarns(brax_model.SecurityWarning):
loaded_params = brax_model.load_params(path, allow_pickle=True)

self.assertEqual(params, loaded_params)

def test_rce_prevention(self):
"""Verifies that malicious payloads are blocked before deserialization."""

class Malicious:

def __reduce__(self):
return (os.system, ('echo RCE_EXPLOITED',))

with tempfile.TemporaryDirectory() as tmpdir:
path = os.path.join(tmpdir, 'malicious.pkl')
with open(path, 'wb') as f:
f.write(pickle.dumps(Malicious()))

# Should raise RuntimeError and NOT execute the payload
with self.assertRaises(RuntimeError):
brax_model.load_params(path)


if __name__ == '__main__':
absltest.main()
13 changes: 10 additions & 3 deletions brax/training/agents/apg/train_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,12 @@

from absl.testing import absltest
from absl.testing import parameterized
import jax

from brax import envs
from brax.training.acme import running_statistics
from brax.training.agents.apg import networks as apg_networks
from brax.training.agents.apg import train as apg
import jax


class APGTest(parameterized.TestCase):
Expand Down Expand Up @@ -62,8 +63,14 @@ def testNetworkEncoding(self, normalize_observations):
env.observation_size, env.action_size, normalize_fn
)
inference = apg_networks.make_inference_fn(apg_network)
byte_encoding = pickle.dumps(params)
decoded_params = pickle.loads(byte_encoding)
import tempfile

from brax.io import model as brax_model

with tempfile.TemporaryDirectory() as tmpdir:
path = f'{tmpdir}/params.msgpack'
brax_model.save_params(path, params)
decoded_params = brax_model.load_params(path)

# Compute one action.
state = env.reset(jax.random.PRNGKey(0))
Expand Down
13 changes: 10 additions & 3 deletions brax/training/agents/ars/train_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,12 @@

from absl.testing import absltest
from absl.testing import parameterized
import jax

from brax import envs
from brax.training.acme import running_statistics
from brax.training.agents.ars import networks as ars_networks
from brax.training.agents.ars import train as ars
import jax


class ARSTest(parameterized.TestCase):
Expand All @@ -44,8 +45,14 @@ def testModelEncoding(self, normalize_observations):
env.observation_size, env.action_size, normalize_fn
)
inference = ars_networks.make_inference_fn(ars_network)
byte_encoding = pickle.dumps(params)
decoded_params = pickle.loads(byte_encoding)
import tempfile

from brax.io import model as brax_model

with tempfile.TemporaryDirectory() as tmpdir:
path = f'{tmpdir}/params.msgpack'
brax_model.save_params(path, params)
decoded_params = brax_model.load_params(path)

# Compute one action.
state = env.reset(jax.random.PRNGKey(0))
Expand Down
13 changes: 10 additions & 3 deletions brax/training/agents/bc/train_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,12 @@

from absl.testing import absltest
from absl.testing import parameterized
import jax

from brax import envs
from brax.training.acme import running_statistics
from brax.training.agents.bc import networks as bc_networks
from brax.training.agents.bc import train as bc
import jax


class BCTest(parameterized.TestCase):
Expand Down Expand Up @@ -107,8 +108,14 @@ def testNetworkEncoding(self):
make_inference = bc_networks.make_inference_fn(bc_network)

# Test serialization and deserialization
byte_encoding = pickle.dumps(params)
decoded_params = pickle.loads(byte_encoding)
import tempfile

from brax.io import model as brax_model

with tempfile.TemporaryDirectory() as tmpdir:
path = f'{tmpdir}/params.msgpack'
brax_model.save_params(path, params)
decoded_params = brax_model.load_params(path)

# Compute one action with both the original and the reconstructed parameters
state = fast.reset(jax.random.PRNGKey(0))
Expand Down
13 changes: 10 additions & 3 deletions brax/training/agents/es/train_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,12 @@

from absl.testing import absltest
from absl.testing import parameterized
import jax

from brax import envs
from brax.training.acme import running_statistics
from brax.training.agents.es import networks as es_networks
from brax.training.agents.es import train as es
import jax


class ESTest(parameterized.TestCase):
Expand Down Expand Up @@ -54,8 +55,14 @@ def testModelEncoding(self, normalize_observations):
env.observation_size, env.action_size, normalize_fn
)
inference = es_networks.make_inference_fn(es_network)
byte_encoding = pickle.dumps(params)
decoded_params = pickle.loads(byte_encoding)
import tempfile

from brax.io import model as brax_model

with tempfile.TemporaryDirectory() as tmpdir:
path = f'{tmpdir}/params.msgpack'
brax_model.save_params(path, params)
decoded_params = brax_model.load_params(path)

# Compute one action.
state = env.reset(jax.random.PRNGKey(0))
Expand Down
17 changes: 12 additions & 5 deletions brax/training/agents/ppo/train_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,22 @@

import functools
import pickle

from absl.testing import absltest
from absl.testing import parameterized
import jax
from jax import numpy as jnp

from brax import envs
from brax.training.acme import running_statistics
from brax.training.agents.ppo import networks as ppo_networks
from brax.training.agents.ppo import networks_vision as ppo_networks_vision
from brax.training.agents.ppo import train as ppo
import jax
from jax import numpy as jnp


class PPOTest(parameterized.TestCase):
"""Tests for PPO module."""


@parameterized.parameters('ndarray', 'dict_state')
def testTrain(self, obs_mode):
"""Test PPO with a simple env."""
Expand Down Expand Up @@ -211,8 +212,14 @@ def testNetworkEncoding(self, normalize_observations):
env.observation_size, env.action_size, normalize_fn
)
inference = ppo_networks.make_inference_fn(ppo_network)
byte_encoding = pickle.dumps(params)
decoded_params = pickle.loads(byte_encoding)
import tempfile

from brax.io import model as brax_model

with tempfile.TemporaryDirectory() as tmpdir:
path = f'{tmpdir}/params.msgpack'
brax_model.save_params(path, params)
decoded_params = brax_model.load_params(path)

# Compute one action.
state = env.reset(jax.random.PRNGKey(0))
Expand Down
Loading