Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 33 additions & 4 deletions mmv/utils/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,44 @@

"""Checkpoint restoring utilities."""

import pickle

from absl import logging
import dill


class _SafeUnpickler(pickle.Unpickler):
"""Restricts unpickling to numpy arrays and basic Python types.

Replaces dill.load() to prevent arbitrary code execution when loading
checkpoints from untrusted sources. Only the types present in MMV
parameter/state dicts (nested dicts of numpy arrays) are allowed.
"""

_ALLOWED = {
"numpy": {"ndarray", "dtype"},
"numpy.core.multiarray": {"_reconstruct", "scalar"},
"numpy.core": {"multiarray"},
"numpy.dtypes": {"Float32DType", "Float64DType", "Int32DType",
"Int64DType"},
"builtins": {
"dict", "list", "tuple", "set", "str", "int", "float",
"bool", "bytes", "complex",
},
}

def find_class(self, module, name):
allowed_names = self._ALLOWED.get(module, set())
if name in allowed_names:
return super().find_class(module, name)
raise pickle.UnpicklingError(
f"Refusing to unpickle {module}.{name}: not in allowlist.")


def load_checkpoint(checkpoint_path):
    """Load a checkpoint file through the allowlist-restricted unpickler.

    Args:
      checkpoint_path: Path of the checkpoint file to read.

    Returns:
      The unpickled checkpoint object, or None if no file exists at
      `checkpoint_path`.

    Raises:
      pickle.UnpicklingError: If the file references a global that
        `_SafeUnpickler` refuses to resolve.
    """
    try:
        with open(checkpoint_path, "rb") as checkpoint_file:
            # Must not fall back to dill.load()/pickle.load(): those resolve
            # arbitrary globals and would execute code embedded in an
            # untrusted checkpoint before the allowlist is ever consulted.
            checkpoint_data = _SafeUnpickler(checkpoint_file).load()
    except FileNotFoundError:
        # A missing checkpoint is an expected condition, signalled with None.
        return None
    logging.info("Loading checkpoint from %s", checkpoint_path)
    return checkpoint_data