@@ -28,6 +28,7 @@ import com.google.privacy.differentialprivacy.pipelinedp4j.core.FrameworkCollect
2828import com.google.privacy.differentialprivacy.pipelinedp4j.core.FrameworkTable
2929import com.google.privacy.differentialprivacy.pipelinedp4j.core.MetricType
3030import com.google.privacy.differentialprivacy.pipelinedp4j.core.SelectPartitionsParams
31+ import com.google.privacy.differentialprivacy.pipelinedp4j.core.budget.BudgetAllocationDetails
3132import com.google.privacy.differentialprivacy.pipelinedp4j.proto.DpAggregates
3233import com.google.privacy.differentialprivacy.pipelinedp4j.proto.PerFeature
3334import com.google.privacy.differentialprivacy.pipelinedp4j.proto.copy
@@ -71,7 +72,17 @@ protected constructor(
7172 validate()
7273 }
7374
74- protected fun runWithDpEngine (testMode : TestMode ): FrameworkTable <GroupKeysT , DpAggregates > {
75+ /* *
76+ * The result of running the DP engine.
77+ *
78+ * Contains the aggregated metrics and budget allocation details.
79+ */
80+ protected data class DpEngineResult <GroupKeysT >(
81+ val aggregationResults : FrameworkTable <GroupKeysT , DpAggregates >,
82+ val budgetAllocationDetails : List <BudgetAllocationDetails >,
83+ )
84+
85+ protected fun runWithDpEngine (testMode : TestMode ): DpEngineResult <GroupKeysT > {
7586 val dpEngine =
7687 DpEngine .create(
7788 encoderFactory,
@@ -84,13 +95,17 @@ protected constructor(
8495 val extractors =
8596 createDataExtractors(valueExtractor = null , vectorExtractor = null , featureId = null )
8697 val result = dpEngine.selectPartitions(data, createSelectPartitionsParams(), extractors)
87- dpEngine.done()
88-
89- return result.mapToTable(
90- " Add empty DpAggregates" ,
91- groupKeyEncoder,
92- encoderFactory.protos(DpAggregates ::class ),
93- { it to DpAggregates .getDefaultInstance() },
98+ val budgetAllocationDetails = dpEngine.done()
99+
100+ return DpEngineResult (
101+ aggregationResults =
102+ result.mapToTable(
103+ " Add empty DpAggregates" ,
104+ groupKeyEncoder,
105+ encoderFactory.protos(DpAggregates ::class ),
106+ { it to DpAggregates .getDefaultInstance() },
107+ ),
108+ budgetAllocationDetails = budgetAllocationDetails,
94109 )
95110 }
96111
@@ -133,42 +148,48 @@ protected constructor(
133148 aggregateWithDpEngine(dpEngine, featureAggregation, listOf (featureAggregation), partitions)
134149 aggResults.add(result)
135150 }
136- dpEngine.done()
151+ val budgetAllocationDetails = dpEngine.done()
137152
138153 val featureIdPerRun =
139154 if (valueAndVectorAggs.isEmpty()) {
140155 listOf (null )
141156 } else {
142157 valueAndVectorAggs.map { it.getFeatureId() }
143158 }
144- return aggResults
145- .zip(featureIdPerRun)
146- .map { (table, featureId) ->
147- table.mapValues(" TagWithFeatureId" , encoderFactory.protos(DpAggregates ::class )) { _, agg ->
148- if (featureId == null ) {
149- agg
150- } else {
151- val perFeature = constructPerFeature(agg, featureId)
152- dpAggregates {
153- count = agg.count
154- privacyIdCount = agg.privacyIdCount
155- this .perFeature + = perFeature
159+ val aggregationResults =
160+ aggResults
161+ .zip(featureIdPerRun)
162+ .map { (table, featureId) ->
163+ table.mapValues(" TagWithFeatureId" , encoderFactory.protos(DpAggregates ::class )) { _, agg
164+ ->
165+ if (featureId == null ) {
166+ agg
167+ } else {
168+ val perFeature = constructPerFeature(agg, featureId)
169+ dpAggregates {
170+ count = agg.count
171+ privacyIdCount = agg.privacyIdCount
172+ this .perFeature + = perFeature
173+ }
156174 }
157175 }
158176 }
159- }
160- .reduce {
161- acc: FrameworkTable <GroupKeysT , DpAggregates >,
162- table: FrameworkTable <GroupKeysT , DpAggregates > ->
163- acc.flattenWith(" FlattenResultsFromMultipleRuns" , table)
164- }
165- .groupAndCombineValues(" MergeDpAggregates" ) { acc, dpAggregatesFromSingleRun ->
166- acc.copy {
167- count + = dpAggregatesFromSingleRun.count
168- privacyIdCount + = dpAggregatesFromSingleRun.privacyIdCount
169- perFeature + = dpAggregatesFromSingleRun.perFeatureList
177+ .reduce {
178+ acc: FrameworkTable <GroupKeysT , DpAggregates >,
179+ table: FrameworkTable <GroupKeysT , DpAggregates > ->
180+ acc.flattenWith(" FlattenResultsFromMultipleRuns" , table)
170181 }
171- }
182+ .groupAndCombineValues(" MergeDpAggregates" ) { acc, dpAggregatesFromSingleRun ->
183+ acc.copy {
184+ count + = dpAggregatesFromSingleRun.count
185+ privacyIdCount + = dpAggregatesFromSingleRun.privacyIdCount
186+ perFeature + = dpAggregatesFromSingleRun.perFeatureList
187+ }
188+ }
189+ return DpEngineResult (
190+ aggregationResults = aggregationResults,
191+ budgetAllocationDetails = budgetAllocationDetails,
192+ )
172193 }
173194
174195 private fun validate () {
@@ -277,8 +298,9 @@ protected constructor(
277298 private fun requireDistinctValueExtractors (aggregationsPerValue : List <ValueAggregations <* >>) {
278299 val valueExtractorCounts = aggregationsPerValue.groupingBy { it.valueExtractor }.eachCount()
279300 val duplicates = valueExtractorCounts.filter { it.value > 1 }.keys
280- val valueAggregationsWithDuplicates =
281- aggregationsPerValue.filter { it.valueExtractor in duplicates }
301+ val valueAggregationsWithDuplicates = aggregationsPerValue.filter {
302+ it.valueExtractor in duplicates
303+ }
282304 require(duplicates.isEmpty()) {
283305 " There are the same (object reference equality) value extractors used in different aggregateValue() calls. Please merge them into one call." +
284306 " \n Value aggregations with duplicate value extractors:\n ${
@@ -290,8 +312,9 @@ protected constructor(
290312 private fun requireDistinctVectorExtractors (aggregationsPerVector : List <VectorAggregations <* >>) {
291313 val vectorExtractorCounts = aggregationsPerVector.groupingBy { it.vectorExtractor }.eachCount()
292314 val duplicates = vectorExtractorCounts.filter { it.value > 1 }.keys
293- val vectorAggregationsWithDuplicates =
294- aggregationsPerVector.filter { it.vectorExtractor in duplicates }
315+ val vectorAggregationsWithDuplicates = aggregationsPerVector.filter {
316+ it.vectorExtractor in duplicates
317+ }
295318 require(duplicates.isEmpty()) {
296319 " There are the same (object reference equality) vector extractors used in different aggregateVector() calls. Please merge them into one call." +
297320 " \n Vector aggregations with duplicate vector extractors:\n ${
0 commit comments