Skip to content

Commit 9b24617

Browse files
authored
Speed-up quantiles (#609)
1 parent bdc1726 commit 9b24617

2 files changed

Lines changed: 185 additions & 55 deletions

File tree

pipeline_dp/combiners.py

Lines changed: 101 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -30,17 +30,6 @@
3030
ArrayLike = Union[np.ndarray, List[float]]
3131
ExplainComputationReport = Union[Callable, str, List[Union[Callable, str]]]
3232

33-
# This is a workaround which is useful when Quantiles form PyDP are failed to
34-
# be serialized. It was observed only from Google colab. Disabling of the
35-
# proto serialization can work only with LocalBackend.
36-
# Ideally we should find a better solution.
37-
_proto_serialization_disabled = False
38-
39-
40-
def disable_proto_serialization():
41-
global _proto_serialization_disabled
42-
_proto_serialization_disabled = True
43-
4433

4534
class Combiner(abc.ABC):
4635
"""Base class for all combiners.
@@ -626,6 +615,90 @@ def mechanism_spec(self) -> budget_accounting.MechanismSpec:
626615
return self._params.mechanism_spec
627616

628617

618+
class QuantileAccumulator:
619+
"""Accumulator for QuantileCombiner.
620+
621+
It keeps elements in a list up to 1000 elements. Beyond that, it creates a
622+
QuantileTree and adds elements to it. This avoids expensive serialization
623+
of QuantileTree when there are few elements.
624+
"""
625+
626+
TREE_HEIGHT = 4
627+
BRANCHING_FACTOR = 16
628+
MAX_ELEMENTS_IN_LIST = 1000
629+
630+
def __init__(self, min_value: float, max_value: float):
631+
self.min_value = min_value
632+
self.max_value = max_value
633+
self.elements = []
634+
self.tree = None
635+
636+
def add_entry(self, value: float):
637+
if self.tree is not None:
638+
self.tree.add_entry(value)
639+
else:
640+
self.elements.append(value)
641+
if len(self.elements) >= self.MAX_ELEMENTS_IN_LIST:
642+
self.create_tree()
643+
644+
def create_tree(self):
645+
self.tree = quantile_tree.QuantileTree(
646+
self.min_value,
647+
self.max_value,
648+
self.TREE_HEIGHT,
649+
self.BRANCHING_FACTOR,
650+
)
651+
for v in self.elements:
652+
self.tree.add_entry(v)
653+
self.elements = None
654+
655+
def merge(self, other: 'QuantileAccumulator'):
656+
if self.tree is not None and other.tree is not None:
657+
self.tree.merge(
658+
pydp._pydp.bytes_to_summary(other.tree.serialize().to_bytes()))
659+
elif self.tree is not None and other.elements is not None:
660+
for v in other.elements:
661+
self.tree.add_entry(v)
662+
elif self.elements is not None and other.tree is not None:
663+
self.create_tree()
664+
self.tree.merge(
665+
pydp._pydp.bytes_to_summary(other.tree.serialize().to_bytes()))
666+
else:
667+
self.elements.extend(other.elements)
668+
if len(self.elements) >= self.MAX_ELEMENTS_IN_LIST:
669+
self.create_tree()
670+
671+
def __getstate__(self):
672+
if self.tree is not None:
673+
return {
674+
'tree': self.tree.serialize().to_bytes(),
675+
'min_value': self.min_value,
676+
'max_value': self.max_value,
677+
}
678+
else:
679+
return {
680+
'elements': self.elements,
681+
'min_value': self.min_value,
682+
'max_value': self.max_value,
683+
}
684+
685+
def __setstate__(self, state):
686+
self.min_value = state['min_value']
687+
self.max_value = state['max_value']
688+
if 'tree' in state:
689+
self.tree = quantile_tree.QuantileTree(
690+
self.min_value,
691+
self.max_value,
692+
self.TREE_HEIGHT,
693+
self.BRANCHING_FACTOR,
694+
)
695+
self.tree.merge(pydp._pydp.bytes_to_summary(state['tree']))
696+
self.elements = None
697+
else:
698+
self.elements = state['elements']
699+
self.tree = None
700+
701+
629702
class QuantileCombiner(Combiner):
630703
"""Combiner for computing DP quantiles.
631704
@@ -637,41 +710,30 @@ class QuantileCombiner(Combiner):
637710
The accumulator is QuantileTree object serialized to string.
638711
"""
639712

640-
AccumulatorType = Union[bytes, List[float]]
641-
642713
def __init__(self, params, percentiles_to_compute: List[float]):
643714
self._params = params
644715
self._percentiles = percentiles_to_compute
645716
self._quantiles_to_compute = [p / 100 for p in percentiles_to_compute]
646717

647-
def create_accumulator(self, values) -> AccumulatorType:
648-
if _proto_serialization_disabled:
649-
return values
650-
tree = self._create_empty_quantile_tree()
718+
def create_accumulator(self, values) -> QuantileAccumulator:
719+
acc = QuantileAccumulator(
720+
self._params.aggregate_params.min_value,
721+
self._params.aggregate_params.max_value,
722+
)
651723
for value in values:
652-
tree.add_entry(value)
653-
return tree.serialize().to_bytes()
724+
acc.add_entry(value)
725+
return acc
654726

655-
def merge_accumulators(self, accumulator1: AccumulatorType,
656-
accumulator2: AccumulatorType) -> AccumulatorType:
657-
if _proto_serialization_disabled:
658-
return accumulator1 + accumulator2 # union of lists
659-
660-
tree = self._create_empty_quantile_tree()
661-
if accumulator1:
662-
tree.merge(pydp._pydp.bytes_to_summary(accumulator1))
663-
if accumulator2:
664-
tree.merge(pydp._pydp.bytes_to_summary(accumulator2))
665-
return tree.serialize().to_bytes()
666-
667-
def compute_metrics(self, accumulator: AccumulatorType) -> AccumulatorType:
668-
if _proto_serialization_disabled:
669-
tree = self._create_empty_quantile_tree()
670-
for value in accumulator:
671-
tree.add_entry(value)
672-
else:
673-
tree = self._create_empty_quantile_tree()
674-
tree.merge(pydp._pydp.bytes_to_summary(accumulator))
727+
def merge_accumulators(
728+
self, accumulator1: QuantileAccumulator,
729+
accumulator2: QuantileAccumulator) -> QuantileAccumulator:
730+
accumulator1.merge(accumulator2)
731+
return accumulator1
732+
733+
def compute_metrics(self, accumulator: QuantileAccumulator) -> dict:
734+
if accumulator.tree is None:
735+
accumulator.create_tree()
736+
tree = accumulator.tree
675737

676738
quantiles = dp_computations.compute_dp_quantiles(
677739
tree,
@@ -699,16 +761,6 @@ def format_metric_name(p: float):
699761
def explain_computation(self) -> ExplainComputationReport:
700762
return lambda: f"Computed percentiles {self._percentiles} with (eps={self._params.eps} delta={self._params.delta})"
701763

702-
def _create_empty_quantile_tree(self):
703-
# The default tree parameters taken from
704-
# https://github.com/google/differential-privacy/blob/605ec87bcbd4a536995b611132dbf4d341d2e91d/cc/algorithms/quantile-tree.h#L47
705-
DEFAULT_TREE_HEIGHT = 4
706-
DEFAULT_BRANCHING_FACTOR = 16
707-
return quantile_tree.QuantileTree(
708-
self._params.aggregate_params.min_value,
709-
self._params.aggregate_params.max_value, DEFAULT_TREE_HEIGHT,
710-
DEFAULT_BRANCHING_FACTOR)
711-
712764
def mechanism_spec(self) -> budget_accounting.MechanismSpec:
713765
return self._params.mechanism_spec
714766

tests/combiners_test.py

Lines changed: 84 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -779,6 +779,90 @@ def test_compute_metrics_with_noise(self):
779779
0.01) # check that noise is added
780780

781781

782+
class QuantileAccumulatorTest(parameterized.TestCase):
783+
784+
def test_add_entry_keeps_elements(self):
785+
acc = dp_combiners.QuantileAccumulator(min_value=0, max_value=100)
786+
for i in range(10):
787+
acc.add_entry(i)
788+
self.assertIsNone(acc.tree)
789+
self.assertEqual(acc.elements, list(range(10)))
790+
791+
def test_add_entry_creates_tree(self):
792+
acc = dp_combiners.QuantileAccumulator(min_value=0, max_value=1000)
793+
limit = 1000
794+
for i in range(limit + 1):
795+
acc.add_entry(i)
796+
self.assertIsNotNone(acc.tree)
797+
self.assertIsNone(acc.elements)
798+
799+
def test_merge_elements(self):
800+
acc1 = dp_combiners.QuantileAccumulator(min_value=0, max_value=100)
801+
acc1.add_entry(1)
802+
acc2 = dp_combiners.QuantileAccumulator(min_value=0, max_value=100)
803+
acc2.add_entry(2)
804+
805+
acc1.merge(acc2)
806+
self.assertIsNone(acc1.tree)
807+
self.assertEqual(acc1.elements, [1, 2])
808+
809+
def test_merge_tree_and_elements(self):
810+
acc1 = dp_combiners.QuantileAccumulator(min_value=0, max_value=1000)
811+
limit = 1000 # MAX_ELEMENTS_IN_QUANTILE_ACCUMULATOR
812+
for i in range(limit):
813+
acc1.add_entry(i)
814+
self.assertIsNotNone(acc1.tree)
815+
816+
acc2 = dp_combiners.QuantileAccumulator(min_value=0, max_value=1000)
817+
acc2.add_entry(1000)
818+
819+
acc1.merge(acc2)
820+
self.assertIsNotNone(acc1.tree)
821+
822+
def test_merge_elements_and_tree(self):
823+
acc1 = dp_combiners.QuantileAccumulator(min_value=0, max_value=1000)
824+
acc1.add_entry(1)
825+
826+
acc2 = dp_combiners.QuantileAccumulator(min_value=0, max_value=1000)
827+
limit = 1000 # MAX_ELEMENTS_IN_QUANTILE_ACCUMULATOR
828+
for i in range(limit):
829+
acc2.add_entry(i)
830+
self.assertIsNotNone(acc2.tree)
831+
832+
acc1.merge(acc2)
833+
self.assertIsNotNone(acc1.tree)
834+
835+
def test_serialization_elements(self):
836+
acc = dp_combiners.QuantileAccumulator(min_value=0, max_value=100)
837+
acc.add_entry(1)
838+
acc.add_entry(2)
839+
840+
state = acc.__getstate__()
841+
self.assertIn('elements', state)
842+
self.assertNotIn('tree', state)
843+
844+
acc2 = dp_combiners.QuantileAccumulator(min_value=0, max_value=100)
845+
acc2.__setstate__(state)
846+
self.assertEqual(acc2.elements, [1, 2])
847+
self.assertIsNone(acc2.tree)
848+
849+
def test_serialization_tree(self):
850+
acc = dp_combiners.QuantileAccumulator(min_value=0, max_value=1000)
851+
limit = 1000
852+
for i in range(limit):
853+
acc.add_entry(i)
854+
self.assertIsNotNone(acc.tree)
855+
856+
state = acc.__getstate__()
857+
self.assertIn('tree', state)
858+
self.assertNotIn('elements', state)
859+
860+
acc2 = dp_combiners.QuantileAccumulator(min_value=0, max_value=1000)
861+
acc2.__setstate__(state)
862+
self.assertIsNotNone(acc2.tree)
863+
self.assertIsNone(acc2.elements)
864+
865+
782866
class QuantileCombinerTest(parameterized.TestCase):
783867

784868
def _create_combiner(self,
@@ -790,12 +874,6 @@ def _create_combiner(self,
790874
return dp_combiners.QuantileCombiner(params,
791875
percentiles_to_compute=percentiles)
792876

793-
def test_create_accumulator(self):
794-
combiner = self._create_combiner(no_noise=False)
795-
quantile_tree = combiner._create_empty_quantile_tree()
796-
self.assertEqual(16, quantile_tree.branching_factor) # default value
797-
self.assertEqual(4, quantile_tree.height) # default value
798-
799877
def test_compute_metrics_without_merge(self):
800878
# Arrange.
801879
combiner = self._create_combiner(no_noise=True,

0 commit comments

Comments
 (0)