Skip to content

Commit e729fa7

Browse files
authored
Fix unsafe deserialization in ETRecord and export_serialize
Differential Revision: D96389611 Pull Request resolved: #18133
1 parent 4f900b2 commit e729fa7

2 files changed

Lines changed: 38 additions & 12 deletions

File tree

devtools/etrecord/_etrecord.py

Lines changed: 37 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@
66

77
# pyre-unsafe
88

9+
import io
910
import json
1011
import os
11-
import pickle
1212
from typing import BinaryIO, Dict, IO, List, Optional, Union
1313
from zipfile import BadZipFile, ZipFile
1414

@@ -228,15 +228,19 @@ def _save_metadata(self, etrecord_zip: ZipFile) -> None:
228228
)
229229

230230
if self._reference_outputs is not None:
231+
buf = io.BytesIO()
232+
torch.save(self._reference_outputs, buf)
231233
etrecord_zip.writestr(
232234
ETRecordReservedFileNames.REFERENCE_OUTPUTS,
233-
pickle.dumps(self._reference_outputs),
235+
buf.getvalue(),
234236
)
235237

236238
if self._representative_inputs is not None:
239+
buf = io.BytesIO()
240+
torch.save(self._representative_inputs, buf)
237241
etrecord_zip.writestr(
238242
ETRecordReservedFileNames.REPRESENTATIVE_INPUTS,
239-
pickle.dumps(self._representative_inputs),
243+
buf.getvalue(),
240244
)
241245

242246
if self.export_graph_id is not None:
@@ -828,15 +832,37 @@ def parse_etrecord(etrecord_path: str) -> ETRecord: # noqa: C901
828832
)
829833
exported_program = deserialize(serialized_artifact)
830834
elif entry == ETRecordReservedFileNames.REFERENCE_OUTPUTS:
831-
# @lint-ignore PYTHONPICKLEISBAD
832-
reference_outputs = pickle.loads(
833-
etrecord_zip.read(ETRecordReservedFileNames.REFERENCE_OUTPUTS)
834-
)
835+
try:
836+
reference_outputs = torch.load(
837+
io.BytesIO(
838+
etrecord_zip.read(ETRecordReservedFileNames.REFERENCE_OUTPUTS)
839+
),
840+
weights_only=True,
841+
)
842+
except Exception as e:
843+
raise RuntimeError(
844+
"Failed to load reference_outputs from ETRecord. "
845+
"This ETRecord file may have been created with an older "
846+
"version that used pickle serialization. Please regenerate "
847+
"the ETRecord file with the current version of ExecuTorch."
848+
) from e
835849
elif entry == ETRecordReservedFileNames.REPRESENTATIVE_INPUTS:
836-
# @lint-ignore PYTHONPICKLEISBAD
837-
representative_inputs = pickle.loads(
838-
etrecord_zip.read(ETRecordReservedFileNames.REPRESENTATIVE_INPUTS)
839-
)
850+
try:
851+
representative_inputs = torch.load(
852+
io.BytesIO(
853+
etrecord_zip.read(
854+
ETRecordReservedFileNames.REPRESENTATIVE_INPUTS
855+
)
856+
),
857+
weights_only=True,
858+
)
859+
except Exception as e:
860+
raise RuntimeError(
861+
"Failed to load representative_inputs from ETRecord. "
862+
"This ETRecord file may have been created with an older "
863+
"version that used pickle serialization. Please regenerate "
864+
"the ETRecord file with the current version of ExecuTorch."
865+
) from e
840866
elif entry == ETRecordReservedFileNames.EXPORT_GRAPH_ID:
841867
export_graph_id = json.loads(
842868
etrecord_zip.read(ETRecordReservedFileNames.EXPORT_GRAPH_ID)

exir/serde/export_serialize.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -333,7 +333,7 @@ def deserialize_torch_artifact(
333333
return {}
334334
buffer = io.BytesIO(serialized)
335335
buffer.seek(0)
336-
artifact = torch.load(buffer)
336+
artifact = torch.load(buffer, weights_only=True)
337337
assert isinstance(artifact, (tuple, dict))
338338
return artifact
339339

0 commit comments

Comments
 (0)