|
6 | 6 |
|
7 | 7 | # pyre-unsafe |
8 | 8 |
|
| 9 | +import io |
9 | 10 | import json |
10 | 11 | import os |
11 | | -import pickle |
12 | 12 | from typing import BinaryIO, Dict, IO, List, Optional, Union |
13 | 13 | from zipfile import BadZipFile, ZipFile |
14 | 14 |
|
@@ -228,15 +228,19 @@ def _save_metadata(self, etrecord_zip: ZipFile) -> None: |
228 | 228 | ) |
229 | 229 |
|
230 | 230 | if self._reference_outputs is not None: |
| 231 | + buf = io.BytesIO() |
| 232 | + torch.save(self._reference_outputs, buf) |
231 | 233 | etrecord_zip.writestr( |
232 | 234 | ETRecordReservedFileNames.REFERENCE_OUTPUTS, |
233 | | - pickle.dumps(self._reference_outputs), |
| 235 | + buf.getvalue(), |
234 | 236 | ) |
235 | 237 |
|
236 | 238 | if self._representative_inputs is not None: |
| 239 | + buf = io.BytesIO() |
| 240 | + torch.save(self._representative_inputs, buf) |
237 | 241 | etrecord_zip.writestr( |
238 | 242 | ETRecordReservedFileNames.REPRESENTATIVE_INPUTS, |
239 | | - pickle.dumps(self._representative_inputs), |
| 243 | + buf.getvalue(), |
240 | 244 | ) |
241 | 245 |
|
242 | 246 | if self.export_graph_id is not None: |
@@ -828,15 +832,37 @@ def parse_etrecord(etrecord_path: str) -> ETRecord: # noqa: C901 |
828 | 832 | ) |
829 | 833 | exported_program = deserialize(serialized_artifact) |
830 | 834 | elif entry == ETRecordReservedFileNames.REFERENCE_OUTPUTS: |
831 | | - # @lint-ignore PYTHONPICKLEISBAD |
832 | | - reference_outputs = pickle.loads( |
833 | | - etrecord_zip.read(ETRecordReservedFileNames.REFERENCE_OUTPUTS) |
834 | | - ) |
| 835 | + try: |
| 836 | + reference_outputs = torch.load( |
| 837 | + io.BytesIO( |
| 838 | + etrecord_zip.read(ETRecordReservedFileNames.REFERENCE_OUTPUTS) |
| 839 | + ), |
| 840 | + weights_only=True, |
| 841 | + ) |
| 842 | + except Exception as e: |
| 843 | + raise RuntimeError( |
| 844 | + "Failed to load reference_outputs from ETRecord. " |
| 845 | + "This ETRecord file may have been created with an older " |
| 846 | + "version that used pickle serialization. Please regenerate " |
| 847 | + "the ETRecord file with the current version of ExecuTorch." |
| 848 | + ) from e |
835 | 849 | elif entry == ETRecordReservedFileNames.REPRESENTATIVE_INPUTS: |
836 | | - # @lint-ignore PYTHONPICKLEISBAD |
837 | | - representative_inputs = pickle.loads( |
838 | | - etrecord_zip.read(ETRecordReservedFileNames.REPRESENTATIVE_INPUTS) |
839 | | - ) |
| 850 | + try: |
| 851 | + representative_inputs = torch.load( |
| 852 | + io.BytesIO( |
| 853 | + etrecord_zip.read( |
| 854 | + ETRecordReservedFileNames.REPRESENTATIVE_INPUTS |
| 855 | + ) |
| 856 | + ), |
| 857 | + weights_only=True, |
| 858 | + ) |
| 859 | + except Exception as e: |
| 860 | + raise RuntimeError( |
| 861 | + "Failed to load representative_inputs from ETRecord. " |
| 862 | + "This ETRecord file may have been created with an older " |
| 863 | + "version that used pickle serialization. Please regenerate " |
| 864 | + "the ETRecord file with the current version of ExecuTorch." |
| 865 | + ) from e |
840 | 866 | elif entry == ETRecordReservedFileNames.EXPORT_GRAPH_ID: |
841 | 867 | export_graph_id = json.loads( |
842 | 868 | etrecord_zip.read(ETRecordReservedFileNames.EXPORT_GRAPH_ID) |
|
0 commit comments