# Numeric input accepted either as a NumPy array or as a plain list of floats.
ArrayLike = Union[np.ndarray, List[float]]

# One report entry: a callable or a string.
_ExplainReportItem = Union[Callable, str]
# A computation report is a single entry or a list of such entries.
ExplainComputationReport = Union[_ExplainReportItem, List[_ExplainReportItem]]
3232
33- # This is a workaround which is useful when Quantiles form PyDP are failed to
34- # be serialized. It was observed only from Google colab. Disabling of the
35- # proto serialization can work only with LocalBackend.
36- # Ideally we should find a better solution.
37- _proto_serialization_disabled = False
38-
39-
40- def disable_proto_serialization ():
41- global _proto_serialization_disabled
42- _proto_serialization_disabled = True
43-
4433
4534class Combiner (abc .ABC ):
4635 """Base class for all combiners.
@@ -626,6 +615,90 @@ def mechanism_spec(self) -> budget_accounting.MechanismSpec:
626615 return self ._params .mechanism_spec
627616
628617
class QuantileAccumulator:
    """Accumulator for QuantileCombiner.

    Values are buffered in a plain list until the list reaches
    MAX_ELEMENTS_IN_LIST entries. At that point a QuantileTree is created,
    the buffered values are moved into it and all later values go straight
    to the tree. This avoids expensive serialization of QuantileTree when
    there are few elements.

    Exactly one of `elements` / `tree` is set at any time; the other one is
    None.
    """

    # Shape parameters passed to quantile_tree.QuantileTree.
    TREE_HEIGHT = 4
    BRANCHING_FACTOR = 16
    # List size at which the buffered values are converted into a tree.
    MAX_ELEMENTS_IN_LIST = 1000

    def __init__(self, min_value: float, max_value: float):
        """Creates an empty accumulator in list mode.

        Args:
          min_value: lower bound passed to the QuantileTree constructor.
          max_value: upper bound passed to the QuantileTree constructor.
        """
        self.min_value = min_value
        self.max_value = max_value
        self.elements = []
        self.tree = None

    def _new_empty_tree(self):
        """Returns an empty QuantileTree built from this accumulator's bounds."""
        return quantile_tree.QuantileTree(
            self.min_value,
            self.max_value,
            self.TREE_HEIGHT,
            self.BRANCHING_FACTOR,
        )

    def _merge_serialized_tree(self, serialized: bytes):
        """Merges a tree given as serialized bytes into self.tree."""
        self.tree.merge(pydp._pydp.bytes_to_summary(serialized))

    def _convert_to_tree_if_full(self):
        """Switches to tree mode once the list reached its size limit."""
        if len(self.elements) >= self.MAX_ELEMENTS_IN_LIST:
            self.create_tree()

    def add_entry(self, value: float):
        """Adds one value to the tree, or to the list when no tree exists."""
        if self.tree is not None:
            self.tree.add_entry(value)
            return
        self.elements.append(value)
        self._convert_to_tree_if_full()

    def create_tree(self):
        """Moves the buffered values into a new QuantileTree.

        Does nothing when a tree already exists, so a repeated call cannot
        replace the populated tree with an empty one.
        """
        if self.tree is not None:
            return
        # Build in a local first: self.tree is only set once every buffered
        # value has been added.
        tree = self._new_empty_tree()
        for value in self.elements:
            tree.add_entry(value)
        self.tree = tree
        self.elements = None

    def merge(self, other: 'QuantileAccumulator'):
        """Merges `other` into this accumulator in place.

        `other` is only read. If either side holds a tree, the result holds
        a tree; otherwise the lists are concatenated and converted to a tree
        when the size limit is reached.
        """
        if other.tree is not None:
            # No-op when this accumulator is already in tree mode.
            self.create_tree()
            self._merge_serialized_tree(other.tree.serialize().to_bytes())
        elif self.tree is not None:
            for value in other.elements:
                self.tree.add_entry(value)
        else:
            self.elements.extend(other.elements)
            self._convert_to_tree_if_full()

    def __getstate__(self):
        """Returns the pickle state: bounds plus serialized tree or raw list."""
        if self.tree is not None:
            state = {'tree': self.tree.serialize().to_bytes()}
        else:
            state = {'elements': self.elements}
        state['min_value'] = self.min_value
        state['max_value'] = self.max_value
        return state

    def __setstate__(self, state):
        """Restores the accumulator from a dict produced by __getstate__."""
        self.min_value = state['min_value']
        self.max_value = state['max_value']
        if 'tree' in state:
            self.elements = None
            # Rebuild an empty tree from the bounds, then load the
            # serialized content into it.
            self.tree = self._new_empty_tree()
            self._merge_serialized_tree(state['tree'])
        else:
            self.elements = state['elements']
            self.tree = None
701+
629702class QuantileCombiner (Combiner ):
630703 """Combiner for computing DP quantiles.
631704
@@ -637,41 +710,30 @@ class QuantileCombiner(Combiner):
637710 The accumulator is QuantileTree object serialized to string.
638711 """
639712
640- AccumulatorType = Union [bytes , List [float ]]
641-
642713 def __init__ (self , params , percentiles_to_compute : List [float ]):
643714 self ._params = params
644715 self ._percentiles = percentiles_to_compute
645716 self ._quantiles_to_compute = [p / 100 for p in percentiles_to_compute ]
646717
647- def create_accumulator (self , values ) -> AccumulatorType :
648- if _proto_serialization_disabled :
649- return values
650- tree = self ._create_empty_quantile_tree ()
718+ def create_accumulator (self , values ) -> QuantileAccumulator :
719+ acc = QuantileAccumulator (
720+ self ._params .aggregate_params .min_value ,
721+ self ._params .aggregate_params .max_value ,
722+ )
651723 for value in values :
652- tree .add_entry (value )
653- return tree . serialize (). to_bytes ()
724+ acc .add_entry (value )
725+ return acc
654726
655- def merge_accumulators (self , accumulator1 : AccumulatorType ,
656- accumulator2 : AccumulatorType ) -> AccumulatorType :
657- if _proto_serialization_disabled :
658- return accumulator1 + accumulator2 # union of lists
659-
660- tree = self ._create_empty_quantile_tree ()
661- if accumulator1 :
662- tree .merge (pydp ._pydp .bytes_to_summary (accumulator1 ))
663- if accumulator2 :
664- tree .merge (pydp ._pydp .bytes_to_summary (accumulator2 ))
665- return tree .serialize ().to_bytes ()
666-
667- def compute_metrics (self , accumulator : AccumulatorType ) -> AccumulatorType :
668- if _proto_serialization_disabled :
669- tree = self ._create_empty_quantile_tree ()
670- for value in accumulator :
671- tree .add_entry (value )
672- else :
673- tree = self ._create_empty_quantile_tree ()
674- tree .merge (pydp ._pydp .bytes_to_summary (accumulator ))
727+ def merge_accumulators (
728+ self , accumulator1 : QuantileAccumulator ,
729+ accumulator2 : QuantileAccumulator ) -> QuantileAccumulator :
730+ accumulator1 .merge (accumulator2 )
731+ return accumulator1
732+
733+ def compute_metrics (self , accumulator : QuantileAccumulator ) -> dict :
734+ if accumulator .tree is None :
735+ accumulator .create_tree ()
736+ tree = accumulator .tree
675737
676738 quantiles = dp_computations .compute_dp_quantiles (
677739 tree ,
@@ -699,16 +761,6 @@ def format_metric_name(p: float):
699761 def explain_computation (self ) -> ExplainComputationReport :
700762 return lambda : f"Computed percentiles { self ._percentiles } with (eps={ self ._params .eps } delta={ self ._params .delta } )"
701763
702- def _create_empty_quantile_tree (self ):
703- # The default tree parameters taken from
704- # https://github.com/google/differential-privacy/blob/605ec87bcbd4a536995b611132dbf4d341d2e91d/cc/algorithms/quantile-tree.h#L47
705- DEFAULT_TREE_HEIGHT = 4
706- DEFAULT_BRANCHING_FACTOR = 16
707- return quantile_tree .QuantileTree (
708- self ._params .aggregate_params .min_value ,
709- self ._params .aggregate_params .max_value , DEFAULT_TREE_HEIGHT ,
710- DEFAULT_BRANCHING_FACTOR )
711-
712764 def mechanism_spec (self ) -> budget_accounting .MechanismSpec :
713765 return self ._params .mechanism_spec
714766
0 commit comments