From b77db0ab0e6bebb20647a6bb7dbde6ce51f97f98 Mon Sep 17 00:00:00 2001 From: Martin Brodeur <63083086+brodmart@users.noreply.github.com> Date: Thu, 30 Apr 2026 13:46:21 -0400 Subject: [PATCH] security: replace dill.load with SafeUnpickler allowlist dill.load() is equivalent to pickle.load() and executes arbitrary Python code in any loaded checkpoint file. A malicious or compromised checkpoint at --checkpoint_path will achieve full RCE on the loading host. Replace with _SafeUnpickler, a stdlib-pickle subclass that restricts find_class() to only the types present in MMV params/state dicts (nested dicts of numpy arrays). No dill dependency needed. --- mmv/utils/checkpoint.py | 37 +++++++++++++++++++++++++++++++++---- 1 file changed, 33 insertions(+), 4 deletions(-) diff --git a/mmv/utils/checkpoint.py b/mmv/utils/checkpoint.py index e8ff30dd..cec1d02a 100644 --- a/mmv/utils/checkpoint.py +++ b/mmv/utils/checkpoint.py @@ -15,15 +15,44 @@ """Checkpoint restoring utilities.""" +import pickle + from absl import logging -import dill + + +class _SafeUnpickler(pickle.Unpickler): + """Restricts unpickling to numpy arrays and basic Python types. + + Replaces dill.load() to prevent arbitrary code execution when loading + checkpoints from untrusted sources. Only the types present in MMV + parameter/state dicts (nested dicts of numpy arrays) are allowed. + """ + + _ALLOWED = { + "numpy": {"ndarray", "dtype"}, + "numpy.core.multiarray": {"_reconstruct", "scalar"}, + "numpy.core": {"multiarray"}, + "numpy.dtypes": {"Float32DType", "Float64DType", "Int32DType", + "Int64DType"}, + "builtins": { + "dict", "list", "tuple", "set", "str", "int", "float", + "bool", "bytes", "complex", + }, + } + + def find_class(self, module, name): + allowed_names = self._ALLOWED.get(module, set()) + if name in allowed_names: + return super().find_class(module, name) + raise pickle.UnpicklingError( + f"Refusing to unpickle {module}.{name}: not in allowlist.") def load_checkpoint(checkpoint_path): try: - with open(checkpoint_path, 'rb') as checkpoint_file: - checkpoint_data = dill.load(checkpoint_file) - logging.info('Loading checkpoint from %s', checkpoint_path) + with open(checkpoint_path, "rb") as checkpoint_file: + checkpoint_data = _SafeUnpickler(checkpoint_file).load() + logging.info("Loading checkpoint from %s", checkpoint_path) return checkpoint_data except FileNotFoundError: return None