9 changes: 8 additions & 1 deletion README.md
@@ -50,7 +50,14 @@ models, including Large Language Models (LLMs).
[Keras API simple example](https://github.com/google-deepmind/jax_privacy/blob/main/examples/keras_api_example.py)
and
[Gemma fine-tuning notebook](https://github.com/google-deepmind/jax_privacy/blob/main/examples/dp_sgd_keras_gemma3_lora_finetuning_samsum.ipynb)
to get started.
to get started. The wrapper supports both standard fixed-size batches
and opt-in internal Poisson sampling, and the configured privacy
accounting should match the training loop's sampling semantics.
`train_steps` counts optimizer updates, so gradient accumulation needs
to be reflected there, and `validation_split` is intentionally
unsupported because the privacy accountant needs the exact post-split
training-set size. Generator-like inputs also need an explicit
`steps_per_epoch` when their length cannot be inferred.
* **Flax Linen**: Offers greater flexibility for custom model
architectures and training loops, at the cost of some additional
boilerplate. See
25 changes: 22 additions & 3 deletions docs/keras_api.rst
@@ -37,9 +37,28 @@ example below shows that.
This section demonstrates how to integrate the Keras API into a typical
Keras training workflow.

The example below enables ``poisson_sampling_in_fit`` and passes training data
to ``fit()`` as per-example arrays. In that setup, the DP Keras wrapper draws
Poisson-sampled batches internally from those arrays.
The example below uses standard fixed-size batches and sets
``sampling_method=SamplingMethod.FIXED_BATCH_SIZE`` so the privacy accountant
matches the actual training loop.

If you instead want the wrapper to resample random-access array inputs with
Poisson sampling inside ``fit()``, enable ``poisson_sampling_in_fit=True``. In
that mode the wrapper uses Poisson accounting automatically.

For dataset or generator inputs, the wrapper cannot infer the sampling
semantics automatically, so ``sampling_method`` must be set explicitly when
``poisson_sampling_in_fit`` is disabled. Generator-like inputs whose length
cannot be inferred also need an explicit ``steps_per_epoch`` so the wrapper can
bound the privacy budget before training starts.

When ``gradient_accumulation_steps > 1``, ``train_steps`` counts optimizer
updates rather than physical minibatches. In practice, this means you should
divide the total number of minibatches your training loop will execute by
``gradient_accumulation_steps`` and round down.
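
As a rough sketch of that arithmetic (``compute_train_steps`` is a hypothetical helper, not part of jax_privacy, and it assumes a fixed-batch loop that drops any partial final batch in each epoch):

```python
def compute_train_steps(
    train_size: int,
    batch_size: int,
    epochs: int,
    gradient_accumulation_steps: int = 1,
) -> int:
  """Returns the number of optimizer updates, not physical minibatches.

  Hypothetical helper for illustration only. Assumes every epoch runs
  `train_size // batch_size` full minibatches and drops the remainder.
  """
  if batch_size <= 0 or gradient_accumulation_steps <= 0:
    raise ValueError(
        'batch_size and gradient_accumulation_steps must be positive.'
    )
  minibatches_per_epoch = train_size // batch_size
  total_minibatches = epochs * minibatches_per_epoch
  # One optimizer update happens per `gradient_accumulation_steps`
  # minibatches; round down so the privacy budget is never under-counted
  # relative to the updates actually applied.
  return total_minibatches // gradient_accumulation_steps
```

The result would be the value to pass as ``train_steps``; with ``gradient_accumulation_steps=1`` it reduces to ``epochs * (train_size // batch_size)``, the expression used in the Keras example.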

``validation_split`` is not supported for DP Keras training. Create the
training/validation split explicitly so ``train_size`` matches the exact number
of training examples seen by the privacy accountant.
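
One way to build that explicit split, assuming the data fits in memory as NumPy arrays (``explicit_split`` is a hypothetical helper, not a jax_privacy or Keras function):

```python
import numpy as np


def explicit_split(x, y, validation_fraction, seed=0):
  """Shuffles once, then splits into disjoint train and validation sets."""
  if not 0.0 <= validation_fraction < 1.0:
    raise ValueError('validation_fraction must be in [0, 1).')
  rng = np.random.default_rng(seed)
  perm = rng.permutation(len(x))
  num_val = int(len(x) * validation_fraction)
  val_idx, train_idx = perm[:num_val], perm[num_val:]
  return (x[train_idx], y[train_idx]), (x[val_idx], y[val_idx])


# Toy data standing in for a real dataset.
x = np.arange(1000, dtype=np.float32).reshape(1000, 1)
y = np.arange(1000)
(x_train, y_train), (x_val, y_val) = explicit_split(x, y, 0.1)

# This is the exact post-split size the privacy accountant needs.
train_size = len(x_train)
```

The held-out arrays can then be passed to ``fit()`` through ``validation_data``, as the example file does with its test set.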

.. literalinclude:: ../examples/keras_api_example.py
:language: python
8 changes: 4 additions & 4 deletions docs/overview.md
@@ -56,7 +56,7 @@ Then on top of the core library the following backend-specific public high-level
* [Keras](https://github.com/google-deepmind/jax_privacy/tree/main/jax_privacy/keras_api.py)

These APIs abstract some complexity and reduce the amount of code necessary to
implement DP training at the cost of less flexibility. Currently, the only
supported mechanism available when using the Keras API is DP-SGD with
internally Poisson-sampled batches built from random-access per-example arrays
(with accounting done using the same Poisson-sampling assumption).
implement DP training at the cost of less flexibility. The Keras API currently
supports DP-SGD for both standard fixed-size batches and opt-in internal
Poisson sampling from random-access per-example arrays, with the privacy
accounting aligned to the chosen sampling semantics.
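
To make the two sampling semantics concrete, the sketch below contrasts how batches are formed under each, using plain NumPy. The function names are hypothetical and do not reflect the library's internals. The fixed-size variant assumes shuffling without replacement and dropping the remainder, which is only one possible fixed-batch scheme.

```python
import numpy as np


def poisson_batch(num_examples, sampling_prob, rng):
  """Includes each example independently, so the batch size is random."""
  mask = rng.random(num_examples) < sampling_prob
  return np.flatnonzero(mask)


def fixed_size_batches(num_examples, batch_size, rng):
  """Shuffles once and cuts into equal batches, dropping the remainder."""
  perm = rng.permutation(num_examples)
  usable = (num_examples // batch_size) * batch_size
  return perm[:usable].reshape(-1, batch_size)


rng = np.random.default_rng(0)
# The Poisson batch size varies from draw to draw around 1000 * 0.128.
print(len(poisson_batch(1000, 0.128, rng)))
# Every fixed-size batch has exactly 128 examples.
print(fixed_size_batches(1000, 128, rng).shape)
```

Because the two schemes select examples differently, an accountant configured for one does not describe a training loop that uses the other.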
11 changes: 7 additions & 4 deletions examples/jax_api_example.py
@@ -20,9 +20,11 @@
that is generated using a known w and b. The goal is to learn w and b from the
synthetic dataset and compare the learned parameters with the known w and b.

The expected final loss should be very close to zero, ~0.0005 and the learned
w and b should be very close to the true w and b (max absolute error should be
smaller than 0.3).
The expected final loss should decrease substantially over training and the
learned w and b should move toward the true w and b. Under the DP setting in
this example, the fixed-batch accounting assumption yields a noticeably harsher
noise level than the previous Poisson default, so the learned parameters are
not expected to match the non-DP run as closely.
"""

from typing import Any, Mapping, Tuple
@@ -126,6 +128,7 @@ def main(_):
num_updates=num_epochs * train_size // batch_size,
num_samples=train_size,
target_delta=1e-5,
sampling_method=analysis.SamplingMethod.FIXED_BATCH_SIZE,
)
noise_rng = random.key(42)
grad_and_value_fn = jax_privacy.clipped_grad(
@@ -213,7 +216,7 @@ def update_step(
print(f"True parameters: w={true_w:.4f}, b={true_b:.4f}")

if use_dp:
assert abs(model_params["w"] - true_w) < 0.6, "w is too far from true_w!"
assert abs(model_params["w"] - true_w) < 1.0, "w is too far from true_w!"
assert abs(model_params["b"] - true_b) < 0.6, "b is too far from true_b!"
else:
assert abs(model_params["w"] - true_w) < 0.1, "w is too far from true_w!"
45 changes: 35 additions & 10 deletions examples/keras_api_example.py
@@ -23,6 +23,7 @@

os.environ["KERAS_BACKEND"] = "jax"
# pylint: disable=g-import-not-at-top,wrong-import-position
from jax_privacy.accounting import analysis
from jax_privacy import keras_api
import keras
from keras import layers
@@ -77,6 +78,7 @@ def main(_):
batch_size = 128
train_size = (len(x_train) // batch_size) * batch_size
x_train, y_train = x_train[:train_size], y_train[:train_size]
keras.utils.set_random_seed(0)
model = get_model()

epsilon = 1.1
@@ -93,7 +95,7 @@ def main(_):
batch_size=batch_size,
train_steps=epochs * (train_size // batch_size),
train_size=train_size,
poisson_sampling_in_fit=True,
sampling_method=analysis.SamplingMethod.FIXED_BATCH_SIZE,
seed=0,
gradient_accumulation_steps=1,
)
@@ -102,8 +104,7 @@ def main(_):
f"DP training:{epsilon=} {delta=} {clipping_norm=} {batch_size=} "
f"{epochs=} {train_size=}"
)
# This example opts into internal Poisson sampling from the per-example
# arrays passed to fit().
print("Using fixed-size batches with fixed-batch accounting.")
else:
print("Non-DP training")
model.compile(
@@ -114,19 +115,43 @@ def main(_):
fit_kwargs = dict(
x=x_train,
y=y_train,
batch_size=batch_size,
epochs=epochs,
validation_data=(x_test, y_test),
)
if not dp:
fit_kwargs["batch_size"] = batch_size
history = model.fit(**fit_kwargs)
model.fit(**fit_kwargs)
train_metrics = model.evaluate(
x_train,
y_train,
batch_size=batch_size,
verbose=0,
return_dict=True,
)
val_metrics = model.evaluate(
x_test,
y_test,
batch_size=batch_size,
verbose=0,
return_dict=True,
)
# [END example]
print("DP: expected train accuracy: ~96%, val accuracy: ~92%")
print("Non-DP: expected train accuracy: ~98%, val accuracy: ~98%")
final_accuracy = history.history["accuracy"][-1]
print(
"DP: expected evaluated train accuracy: >60%,"
" evaluated val accuracy depends on epsilon"
)
print(
"Non-DP: expected evaluated train accuracy: ~98%,"
" evaluated val accuracy: ~98%"
)
print(
"Final evaluated metrics:"
f" train_accuracy={train_metrics['accuracy']:.4f},"
f" val_accuracy={val_metrics['accuracy']:.4f}"
)
final_accuracy = train_metrics["accuracy"]
if dp:
assert (
final_accuracy > 0.85
final_accuracy > 0.60
), f"DP Accuracy {final_accuracy:.4f} is too low!"
else:
assert (
15 changes: 15 additions & 0 deletions jax_privacy/accounting/calibrate.py
@@ -52,6 +52,7 @@ def calibrate_num_updates(
target_delta: float,
examples_per_user: int | None = None,
cycle_length: int | None = None,
sampling_method: analysis.SamplingMethod = analysis.SamplingMethod.POISSON,
truncated_batch_size: int | None = None,
initial_max_updates: int = 4,
initial_min_updates: int = 1,
@@ -74,6 +75,9 @@ def calibrate_num_updates(
maximum number any user contributes to the training set.
cycle_length: If using cyclic Poisson sampling with BandMF, the length of
the cycle.
sampling_method: The sampling method assumed by the privacy analysis.
Defaults to `SamplingMethod.POISSON`; callers using fixed-size batches
should pass `SamplingMethod.FIXED_BATCH_SIZE`.
truncated_batch_size: If using truncated Poisson sampling, the maximum batch
size to truncate to.
initial_max_updates: An initial estimate of the number of updates.
@@ -94,6 +98,7 @@ def get_epsilon(num_updates: int) -> float:
delta=target_delta,
examples_per_user=examples_per_user,
cycle_length=cycle_length,
sampling_method=sampling_method,
truncated_batch_size=truncated_batch_size,
)
return accountant.compute_epsilon(num_updates, dp_params)
@@ -132,6 +137,7 @@ def calibrate_noise_multiplier(
target_delta: float,
examples_per_user: int | None = None,
cycle_length: int | None = None,
sampling_method: analysis.SamplingMethod = analysis.SamplingMethod.POISSON,
truncated_batch_size: int | None = None,
initial_max_noise: float = 1.0,
initial_min_noise: float = 0.0,
@@ -152,6 +158,9 @@ def calibrate_noise_multiplier(
maximum number any user contributes to the training set.
cycle_length: If using cyclic Poisson sampling with BandMF, the length of
the cycle.
sampling_method: The sampling method assumed by the privacy analysis.
Defaults to `SamplingMethod.POISSON`; callers using fixed-size batches
should pass `SamplingMethod.FIXED_BATCH_SIZE`.
truncated_batch_size: If using truncated Poisson sampling, the maximum batch
size to truncate to.
initial_max_noise: An initial estimate of the noise multiplier.
@@ -174,6 +183,7 @@ def get_epsilon(noise_multiplier: float) -> float:
delta=target_delta,
examples_per_user=examples_per_user,
cycle_length=cycle_length,
sampling_method=sampling_method,
truncated_batch_size=truncated_batch_size,
)
return accountant.compute_epsilon(num_updates, dp_params)
@@ -201,6 +211,7 @@ def calibrate_batch_size(
target_delta: float,
examples_per_user: int | None = None,
cycle_length: int | None = None,
sampling_method: analysis.SamplingMethod = analysis.SamplingMethod.POISSON,
truncated_batch_size: int | None = None,
initial_max_batch_size: int = 8,
initial_min_batch_size: int = 1,
@@ -221,6 +232,9 @@ def calibrate_batch_size(
maximum number any user contributes to the training set.
cycle_length: If using cyclic Poisson sampling with BandMF, the length of
the cycle.
sampling_method: The sampling method assumed by the privacy analysis.
Defaults to `SamplingMethod.POISSON`; callers using fixed-size batches
should pass `SamplingMethod.FIXED_BATCH_SIZE`.
truncated_batch_size: If using truncated Poisson sampling, the maximum batch
size to truncate to.
initial_max_batch_size: An initial estimate of the batch size.
@@ -243,6 +257,7 @@ def get_epsilon(batch_size: int) -> float:
delta=target_delta,
examples_per_user=examples_per_user,
cycle_length=cycle_length,
sampling_method=sampling_method,
truncated_batch_size=truncated_batch_size,
)
return accountant.compute_epsilon(num_updates, dp_params)
30 changes: 30 additions & 0 deletions jax_privacy/batch_selection.py
@@ -129,6 +129,12 @@ def split_and_pad_global_batch(
The last minibatch may contain extra `-1` indices representing padding
examples to make it the right size.
"""
if minibatch_size <= 0:
raise ValueError(f'minibatch_size must be positive, got {minibatch_size}.')
if microbatch_size is not None and microbatch_size <= 0:
raise ValueError(
f'microbatch_size must be positive when set, got {microbatch_size}.'
)
sections = range(minibatch_size, indices.shape[0], minibatch_size)
minibatches = np.array_split(indices, sections, axis=0)
minibatch_shape = (minibatch_size,) + indices.shape[1:]
@@ -239,6 +245,16 @@ class CyclicPoissonSampling(BatchSelectionStrategy):
cycle_length: int = 1
partition_type: PartitionType = PartitionType.EQUAL_SPLIT

def __post_init__(self):
if not 0 <= self.sampling_prob <= 1:
raise ValueError('sampling_prob must be in [0, 1].')
if self.iterations < 0:
raise ValueError('iterations must be non-negative.')
if self.cycle_length <= 0:
raise ValueError('cycle_length must be positive.')
if self.truncated_batch_size is not None and self.truncated_batch_size < 0:
raise ValueError('truncated_batch_size must be non-negative.')

def batch_iterator(
self, num_examples: int, rng: RngType = None
) -> Iterator[np.ndarray]:
@@ -290,6 +306,12 @@ class BallsInBinsSampling(BatchSelectionStrategy):
iterations: int
cycle_length: int

def __post_init__(self):
if self.iterations < 0:
raise ValueError('iterations must be non-negative.')
if self.cycle_length <= 0:
raise ValueError('cycle_length must be positive.')

def batch_iterator(
self, num_examples: int, rng: RngType = None
) -> Iterator[np.ndarray]:
@@ -322,6 +344,10 @@ class FixedBatchSampling(BatchSelectionStrategy):
iterations: int
replace: bool = False

def __post_init__(self):
if self.iterations < 0:
raise ValueError('iterations must be non-negative.')

def batch_iterator(
self, num_examples: int, rng: RngType = None
) -> Iterator[np.ndarray]:
@@ -482,6 +508,10 @@ class UserSelectionStrategy:
examples_per_user_per_batch: int = 1
shuffle_per_user: bool = False

def __post_init__(self):
if self.examples_per_user_per_batch <= 0:
raise ValueError('examples_per_user_per_batch must be positive.')

def batch_iterator(
self, user_ids: np.ndarray, rng: RngType = None
) -> Iterator[np.ndarray]: