Skip to content

Commit 8cc5df1

Browse files
Merge pull request #3877 from AI-Hypercomputer:bvandermoon-remediate-rce-deserialization
PiperOrigin-RevId: 914493298
2 parents (4d9f390 + d640cad), commit 8cc5df1

16 files changed

Lines changed: 115 additions & 38 deletions

src/maxtext/checkpoint_conversion/standalone_scripts/llama4_ckpt_unscanned.py

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -600,9 +600,8 @@ def _convert_pytorch_to_jax_weights(base_model_path: str, model_size: str, model
600600
for i, ckpt_path in enumerate(ckpt_paths):
601601
max_logging.log(f"Loading checkpoint {i+1} of {len(ckpt_paths)} ...")
602602
# NOTE: starting in PT2.6, `weights_only` was switched from the default of `False` to `True`
603-
# thus we need to specify this or else loading will fail
604603
chkpt_vars[int(ckpt_path.name.split(".", maxsplit=2)[1])] = torch.load(
605-
ckpt_path, map_location="cpu", weights_only=False
604+
ckpt_path, map_location="cpu", weights_only=True
606605
)
607606
chkpt_vars = [chkpt_vars[i] for i in sorted(list(chkpt_vars.keys()))]
608607
# map weight names if they use HuggingFace instead of PyTorch convention

src/maxtext/checkpoint_conversion/standalone_scripts/llama_ckpt_conversion_inference_only.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -157,7 +157,7 @@ def convert(base_model_path, maxtext_model_path, model_size):
157157
for i, ckpt_path in enumerate(ckpt_paths):
158158
print(f"Loading checkpoint {i+1} of {len(ckpt_paths)} ...")
159159

160-
checkpoint = torch.load(ckpt_path, map_location="cpu")
160+
checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=True)
161161
pytorch_vars[int(ckpt_path.name.split(".", maxsplit=2)[1])] = checkpoint
162162
print("memory usage in GB: ", psutil.Process().memory_info().rss / (1024 * 1024))
163163

src/maxtext/checkpoint_conversion/standalone_scripts/llama_or_mistral_ckpt.py

Lines changed: 2 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -428,7 +428,7 @@ def convert_lora_weights_to_jax_weights(lora_config: dict, model_size: str):
428428

429429
max_logging.log(f"Loading the lora model from {lora_config['lora_model_path']}")
430430
# Load LoRA model weights
431-
lora_chkpt_vars = torch.load(lora_config["lora_model_path"])
431+
lora_chkpt_vars = torch.load(lora_config["lora_model_path"], weights_only=True)
432432
lora_chkpt_vars = _NamespaceMapper(lora_chkpt_vars)
433433

434434
jax_weights_lora = {
@@ -1112,9 +1112,8 @@ def _convert_pytorch_to_jax_weights(base_model_path: str, model_size: str, model
11121112
for i, ckpt_path in enumerate(ckpt_paths):
11131113
max_logging.log(f"Loading checkpoint {i+1} of {len(ckpt_paths)} ...")
11141114
# NOTE: starting in PT2.6, `weights_only` was switched from the default of `False` to `True`
1115-
# thus we need to specify this or else loading will fail
11161115
chkpt_vars[int(ckpt_path.name.split(".", maxsplit=2)[1])] = torch.load(
1117-
ckpt_path, map_location="cpu", weights_only=False
1116+
ckpt_path, map_location="cpu", weights_only=True
11181117
)
11191118
chkpt_vars = [chkpt_vars[i] for i in sorted(list(chkpt_vars.keys()))]
11201119
# map weight names if they use HuggingFace instead of PyTorch convention

src/maxtext/inference/mlperf/evaluate-accuracy-fast.py

Lines changed: 14 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -19,6 +19,7 @@
1919
import json
2020
import nltk
2121
import numpy as np
22+
import os
2223
import pandas as pd
2324
import tqdm
2425

@@ -74,7 +75,19 @@ def get_args():
7475

7576

7677
def get_groundtruth(processed_dataset_file):
77-
data = pd.read_pickle(processed_dataset_file)
78+
"""Load the ground truth labels from the processed dataset file securely."""
79+
ext = os.path.splitext(processed_dataset_file)[1].lower()
80+
if ext == ".parquet":
81+
data = pd.read_parquet(processed_dataset_file)
82+
elif ext == ".csv":
83+
data = pd.read_csv(processed_dataset_file)
84+
elif ext in (".json", ".jsonl"):
85+
data = pd.read_json(processed_dataset_file)
86+
else:
87+
raise ValueError(
88+
f"Unsupported dataset file format: {processed_dataset_file}. "
89+
"Please use safe formats like Parquet (.parquet), CSV (.csv), or JSON/JSONL (.json/.jsonl)."
90+
)
7891
return data["output"]
7992

8093

src/maxtext/inference/mlperf/evaluate-accuracy.py

Lines changed: 14 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -24,6 +24,7 @@
2424

2525
import numpy as np
2626

27+
import os
2728
import pandas as pd
2829

2930

@@ -39,7 +40,19 @@ def get_args():
3940

4041

4142
def get_groundtruth(processed_dataset_file):
42-
data = pd.read_pickle(processed_dataset_file)
43+
"""Load the ground truth labels from the processed dataset file securely."""
44+
ext = os.path.splitext(processed_dataset_file)[1].lower()
45+
if ext == ".parquet":
46+
data = pd.read_parquet(processed_dataset_file)
47+
elif ext == ".csv":
48+
data = pd.read_csv(processed_dataset_file)
49+
elif ext in (".json", ".jsonl"):
50+
data = pd.read_json(processed_dataset_file)
51+
else:
52+
raise ValueError(
53+
f"Unsupported dataset file format: {processed_dataset_file}. "
54+
"Please use safe formats like Parquet (.parquet), CSV (.csv), or JSON/JSONL (.json/.jsonl)."
55+
)
4356
ground_truths = data["output"]
4457
return ground_truths
4558

src/maxtext/inference/mlperf/offline_mode.py

Lines changed: 12 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -375,8 +375,18 @@ def main():
375375
log.info("Mlperf config: %s", args.mlperf_conf)
376376
log.info("User config: %s", user_conf)
377377

378-
log.info("dataset path: %s", args.dataset_path)
379-
dataset = pd.read_pickle(args.dataset_path)
378+
ext = os.path.splitext(args.dataset_path)[1].lower()
379+
if ext == ".parquet":
380+
dataset = pd.read_parquet(args.dataset_path)
381+
elif ext == ".csv":
382+
dataset = pd.read_csv(args.dataset_path)
383+
elif ext in (".json", ".jsonl"):
384+
dataset = pd.read_json(args.dataset_path)
385+
else:
386+
raise ValueError(
387+
f"Unsupported dataset file format: {args.dataset_path}. "
388+
"Please use safe formats like Parquet (.parquet), CSV (.csv), or JSON/JSONL (.json/.jsonl)."
389+
)
380390
if args.rename_dataset_cols:
381391
rename_dict = json.loads(args.rename_dataset_cols)
382392
dataset.rename(columns=rename_dict, inplace=True)

src/maxtext/trainers/post_train/distillation/distillation_utils.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -19,7 +19,7 @@
1919
"""
2020

2121
import abc
22-
import pickle
22+
import safetensors.numpy
2323
from typing import Any, Callable, Iterator, List, Literal, Optional, Sequence
2424

2525
import flax
@@ -110,7 +110,7 @@ def __next__(self):
110110

111111
record = self.reader.read()
112112
self.record_index += 1
113-
data = pickle.loads(record)
113+
data = safetensors.numpy.load(record)
114114

115115
# Map the arrays to match MaxText's expected dictionary
116116
batch = {

src/maxtext/trainers/post_train/distillation/save_top_k_teacher_logits.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -24,7 +24,7 @@
2424
"""
2525

2626
import os
27-
import pickle
27+
import safetensors.numpy
2828
from typing import Sequence
2929
import argparse
3030
import time
@@ -165,7 +165,7 @@ def generate_and_save_data(config, local_args):
165165
if key in batch:
166166
record_dict[key] = jax.device_get(batch[key])
167167

168-
writer.write(pickle.dumps(record_dict))
168+
writer.write(safetensors.numpy.save(record_dict))
169169

170170
if step % 50 == 0:
171171
max_logging.log(f"Successfully processed step {step} in {time.time() - step_start:.4f}s")

src/maxtext/trainers/post_train/distillation/verify_saved_logits.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -25,7 +25,7 @@
2525
import sys
2626

2727
import argparse
28-
import pickle
28+
import safetensors.numpy
2929
from absl import app
3030
import tensorflow as tf
3131
from array_record.python import array_record_module
@@ -57,7 +57,7 @@ def verify_array_records(output_dir, expected_steps, expected_k, expected_keys):
5757

5858
for record_idx in range(num_records_in_file):
5959
record = reader.read()
60-
data = pickle.loads(record)
60+
data = safetensors.numpy.load(record)
6161

6262
# Verify all required keys are present
6363
required_keys = ["tokens", "top_k_logits", "top_k_indices"]

src/maxtext/trainers/pre_train/train_compile.py

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -23,7 +23,6 @@
2323

2424
import functools
2525
import os
26-
import pickle
2726
from typing import Sequence
2827

2928
from absl import app
@@ -181,7 +180,7 @@ def save_compiled(compiled, save_name):
181180
"""Serialize and save the compiled function."""
182181
serialized, _, _ = serialize(compiled)
183182
with open(save_name, "wb") as f:
184-
pickle.dump(serialized, f)
183+
f.write(serialized)
185184

186185

187186
def is_oom(argv: Sequence[str]) -> bool:

0 commit comments

Comments
 (0)