Skip to content

Commit c44b992

Browse files
RamSawcopybara-github
authored andcommitted
Create DPEngineResult data class in API package to pass both aggregations and allocated budgets
PiperOrigin-RevId: 839205425
1 parent 69e3ec8 commit c44b992

14 files changed

Lines changed: 264 additions & 67 deletions

File tree

pipelinedp4j/main/com/google/privacy/differentialprivacy/pipelinedp4j/api/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ kt_jvm_library(
3131
"//main/com/google/privacy/differentialprivacy/pipelinedp4j/core:dp_functions_params",
3232
"//main/com/google/privacy/differentialprivacy/pipelinedp4j/core:encoders",
3333
"//main/com/google/privacy/differentialprivacy/pipelinedp4j/core:framework_collections",
34+
"//main/com/google/privacy/differentialprivacy/pipelinedp4j/core/budget:budget_allocation_details",
3435
"//main/com/google/privacy/differentialprivacy/pipelinedp4j/core/budget:budget_spec",
3536
"//main/com/google/privacy/differentialprivacy/pipelinedp4j/local:local_collections",
3637
"//main/com/google/privacy/differentialprivacy/pipelinedp4j/local:local_encoders",

pipelinedp4j/main/com/google/privacy/differentialprivacy/pipelinedp4j/api/BeamApi.kt

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -342,15 +342,16 @@ internal constructor(
342342
noiseKind,
343343
) {
344344
override fun run(testMode: TestMode): BeamPCollection<QueryPerGroupResult<GroupKeysT>> {
345-
val beamResult =
346-
(runWithDpEngine(testMode) as BeamFrameworkTable<GroupKeysT, DpAggregates>).data
345+
val dpEngineResult: DpEngineResult<GroupKeysT> = runWithDpEngine(testMode)
346+
val beamAggregationResults =
347+
(dpEngineResult.aggregationResults as BeamFrameworkTable<GroupKeysT, DpAggregates>).data
347348
val coder = QueryPerGroupResultCoder(groupKeyEncoder.coder)
348349
val mapToResultFn =
349350
createConvertDpAggregatesToQueryPerGroupResultFn(
350351
aggregations.outputColumnNamesWithMetricTypes(),
351352
aggregations.outputColumnNameToFeatureIdMap(),
352353
)
353-
return beamResult
354+
return beamAggregationResults
354355
.apply(MapElements.into(coder.encodedTypeDescriptor).via(SerializableFunction(mapToResultFn)))
355356
.setCoder(coder)
356357
}

pipelinedp4j/main/com/google/privacy/differentialprivacy/pipelinedp4j/api/LocalApi.kt

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -228,14 +228,15 @@ internal constructor(
228228
noiseKind,
229229
) {
230230
override fun run(testMode: TestMode): Sequence<QueryPerGroupResult<GroupKeysT>> {
231-
val localResult =
232-
(runWithDpEngine(testMode) as LocalFrameworkTable<GroupKeysT, DpAggregates>).data
231+
val dpEngineResult: DpEngineResult<GroupKeysT> = runWithDpEngine(testMode)
232+
val localAggregationResults =
233+
(dpEngineResult.aggregationResults as LocalFrameworkTable<GroupKeysT, DpAggregates>).data
233234
val mapToResultFn =
234235
createConvertDpAggregatesToQueryPerGroupResultFn(
235236
aggregations.outputColumnNamesWithMetricTypes(),
236237
aggregations.outputColumnNameToFeatureIdMap(),
237238
)
238-
return localResult.map(mapToResultFn)
239+
return localAggregationResults.map(mapToResultFn)
239240
}
240241

241242
private fun createConvertDpAggregatesToQueryPerGroupResultFn(

pipelinedp4j/main/com/google/privacy/differentialprivacy/pipelinedp4j/api/Query.kt

Lines changed: 60 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ import com.google.privacy.differentialprivacy.pipelinedp4j.core.FrameworkCollect
2828
import com.google.privacy.differentialprivacy.pipelinedp4j.core.FrameworkTable
2929
import com.google.privacy.differentialprivacy.pipelinedp4j.core.MetricType
3030
import com.google.privacy.differentialprivacy.pipelinedp4j.core.SelectPartitionsParams
31+
import com.google.privacy.differentialprivacy.pipelinedp4j.core.budget.BudgetAllocationDetails
3132
import com.google.privacy.differentialprivacy.pipelinedp4j.proto.DpAggregates
3233
import com.google.privacy.differentialprivacy.pipelinedp4j.proto.PerFeature
3334
import com.google.privacy.differentialprivacy.pipelinedp4j.proto.copy
@@ -71,7 +72,17 @@ protected constructor(
7172
validate()
7273
}
7374

74-
protected fun runWithDpEngine(testMode: TestMode): FrameworkTable<GroupKeysT, DpAggregates> {
75+
/**
76+
* The result of running the DP engine.
77+
*
78+
* Contains the aggregated metrics and budget allocation details.
79+
*/
80+
protected data class DpEngineResult<GroupKeysT>(
81+
val aggregationResults: FrameworkTable<GroupKeysT, DpAggregates>,
82+
val budgetAllocationDetails: List<BudgetAllocationDetails>,
83+
)
84+
85+
protected fun runWithDpEngine(testMode: TestMode): DpEngineResult<GroupKeysT> {
7586
val dpEngine =
7687
DpEngine.create(
7788
encoderFactory,
@@ -84,13 +95,17 @@ protected constructor(
8495
val extractors =
8596
createDataExtractors(valueExtractor = null, vectorExtractor = null, featureId = null)
8697
val result = dpEngine.selectPartitions(data, createSelectPartitionsParams(), extractors)
87-
dpEngine.done()
88-
89-
return result.mapToTable(
90-
"Add empty DpAggregates",
91-
groupKeyEncoder,
92-
encoderFactory.protos(DpAggregates::class),
93-
{ it to DpAggregates.getDefaultInstance() },
98+
val budgetAllocationDetails = dpEngine.done()
99+
100+
return DpEngineResult(
101+
aggregationResults =
102+
result.mapToTable(
103+
"Add empty DpAggregates",
104+
groupKeyEncoder,
105+
encoderFactory.protos(DpAggregates::class),
106+
{ it to DpAggregates.getDefaultInstance() },
107+
),
108+
budgetAllocationDetails = budgetAllocationDetails,
94109
)
95110
}
96111

@@ -133,42 +148,48 @@ protected constructor(
133148
aggregateWithDpEngine(dpEngine, featureAggregation, listOf(featureAggregation), partitions)
134149
aggResults.add(result)
135150
}
136-
dpEngine.done()
151+
val budgetAllocationDetails = dpEngine.done()
137152

138153
val featureIdPerRun =
139154
if (valueAndVectorAggs.isEmpty()) {
140155
listOf(null)
141156
} else {
142157
valueAndVectorAggs.map { it.getFeatureId() }
143158
}
144-
return aggResults
145-
.zip(featureIdPerRun)
146-
.map { (table, featureId) ->
147-
table.mapValues("TagWithFeatureId", encoderFactory.protos(DpAggregates::class)) { _, agg ->
148-
if (featureId == null) {
149-
agg
150-
} else {
151-
val perFeature = constructPerFeature(agg, featureId)
152-
dpAggregates {
153-
count = agg.count
154-
privacyIdCount = agg.privacyIdCount
155-
this.perFeature += perFeature
159+
val aggregationResults =
160+
aggResults
161+
.zip(featureIdPerRun)
162+
.map { (table, featureId) ->
163+
table.mapValues("TagWithFeatureId", encoderFactory.protos(DpAggregates::class)) { _, agg
164+
->
165+
if (featureId == null) {
166+
agg
167+
} else {
168+
val perFeature = constructPerFeature(agg, featureId)
169+
dpAggregates {
170+
count = agg.count
171+
privacyIdCount = agg.privacyIdCount
172+
this.perFeature += perFeature
173+
}
156174
}
157175
}
158176
}
159-
}
160-
.reduce {
161-
acc: FrameworkTable<GroupKeysT, DpAggregates>,
162-
table: FrameworkTable<GroupKeysT, DpAggregates> ->
163-
acc.flattenWith("FlattenResultsFromMultipleRuns", table)
164-
}
165-
.groupAndCombineValues("MergeDpAggregates") { acc, dpAggregatesFromSingleRun ->
166-
acc.copy {
167-
count += dpAggregatesFromSingleRun.count
168-
privacyIdCount += dpAggregatesFromSingleRun.privacyIdCount
169-
perFeature += dpAggregatesFromSingleRun.perFeatureList
177+
.reduce {
178+
acc: FrameworkTable<GroupKeysT, DpAggregates>,
179+
table: FrameworkTable<GroupKeysT, DpAggregates> ->
180+
acc.flattenWith("FlattenResultsFromMultipleRuns", table)
170181
}
171-
}
182+
.groupAndCombineValues("MergeDpAggregates") { acc, dpAggregatesFromSingleRun ->
183+
acc.copy {
184+
count += dpAggregatesFromSingleRun.count
185+
privacyIdCount += dpAggregatesFromSingleRun.privacyIdCount
186+
perFeature += dpAggregatesFromSingleRun.perFeatureList
187+
}
188+
}
189+
return DpEngineResult(
190+
aggregationResults = aggregationResults,
191+
budgetAllocationDetails = budgetAllocationDetails,
192+
)
172193
}
173194

174195
private fun validate() {
@@ -277,8 +298,9 @@ protected constructor(
277298
private fun requireDistinctValueExtractors(aggregationsPerValue: List<ValueAggregations<*>>) {
278299
val valueExtractorCounts = aggregationsPerValue.groupingBy { it.valueExtractor }.eachCount()
279300
val duplicates = valueExtractorCounts.filter { it.value > 1 }.keys
280-
val valueAggregationsWithDuplicates =
281-
aggregationsPerValue.filter { it.valueExtractor in duplicates }
301+
val valueAggregationsWithDuplicates = aggregationsPerValue.filter {
302+
it.valueExtractor in duplicates
303+
}
282304
require(duplicates.isEmpty()) {
283305
"There are the same (object reference equality) value extractors used in different aggregateValue() calls. Please merge them into one call." +
284306
"\nValue aggregations with duplicate value extractors:\n${
@@ -290,8 +312,9 @@ protected constructor(
290312
private fun requireDistinctVectorExtractors(aggregationsPerVector: List<VectorAggregations<*>>) {
291313
val vectorExtractorCounts = aggregationsPerVector.groupingBy { it.vectorExtractor }.eachCount()
292314
val duplicates = vectorExtractorCounts.filter { it.value > 1 }.keys
293-
val vectorAggregationsWithDuplicates =
294-
aggregationsPerVector.filter { it.vectorExtractor in duplicates }
315+
val vectorAggregationsWithDuplicates = aggregationsPerVector.filter {
316+
it.vectorExtractor in duplicates
317+
}
295318
require(duplicates.isEmpty()) {
296319
"There are the same (object reference equality) vector extractors used in different aggregateVector() calls. Please merge them into one call." +
297320
"\nVector aggregations with duplicate vector extractors:\n${

pipelinedp4j/main/com/google/privacy/differentialprivacy/pipelinedp4j/api/SparkApi.kt

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -329,8 +329,9 @@ internal constructor(
329329
noiseKind,
330330
) {
331331
override fun run(testMode: TestMode): SparkDataset<QueryPerGroupResult<GroupKeysT>> {
332-
val sparkResult =
333-
(runWithDpEngine(testMode) as SparkFrameworkTable<GroupKeysT, DpAggregates>).data
332+
val dpEngineResult: DpEngineResult<GroupKeysT> = runWithDpEngine(testMode)
333+
val sparkAggregationResults =
334+
(dpEngineResult.aggregationResults as SparkFrameworkTable<GroupKeysT, DpAggregates>).data
334335
@Suppress("UNCHECKED_CAST")
335336
val queryPerGroupResultEncoder =
336337
Encoders.kryo(QueryPerGroupResult::class.java) as Encoder<QueryPerGroupResult<GroupKeysT>>
@@ -339,7 +340,7 @@ internal constructor(
339340
aggregations.outputColumnNamesWithMetricTypes(),
340341
aggregations.outputColumnNameToFeatureIdMap(),
341342
)
342-
return sparkResult.map(MapFunction(mapToResultFn), queryPerGroupResultEncoder)
343+
return sparkAggregationResults.map(MapFunction(mapToResultFn), queryPerGroupResultEncoder)
343344
}
344345

345346
private fun createConvertDpAggregatesToQueryPerGroupResultFn(

pipelinedp4j/main/com/google/privacy/differentialprivacy/pipelinedp4j/core/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ kt_jvm_library(
115115
":framework_collections",
116116
"//main/com/google/privacy/differentialprivacy/pipelinedp4j/core/budget:allocated_budget",
117117
"//main/com/google/privacy/differentialprivacy/pipelinedp4j/core/budget:budget_accountant",
118+
"//main/com/google/privacy/differentialprivacy/pipelinedp4j/core/budget:budget_allocation_details",
118119
"//main/com/google/privacy/differentialprivacy/pipelinedp4j/core/budget:budget_spec",
119120
"//main/com/google/privacy/differentialprivacy/pipelinedp4j/dplibrary:noise_factories",
120121
"//main/com/google/privacy/differentialprivacy/pipelinedp4j/dplibrary:pre_aggregation_partition_selection_factory",

pipelinedp4j/main/com/google/privacy/differentialprivacy/pipelinedp4j/core/DpEngine.kt

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
package com.google.privacy.differentialprivacy.pipelinedp4j.core
1818

19+
import com.google.errorprone.annotations.CanIgnoreReturnValue
1920
import com.google.privacy.differentialprivacy.Noise
2021
import com.google.privacy.differentialprivacy.pipelinedp4j.core.MetricType.COUNT
2122
import com.google.privacy.differentialprivacy.pipelinedp4j.core.MetricType.MEAN
@@ -35,6 +36,7 @@ import com.google.privacy.differentialprivacy.pipelinedp4j.core.budget.BudgetAcc
3536
import com.google.privacy.differentialprivacy.pipelinedp4j.core.budget.BudgetAccountantFactory
3637
import com.google.privacy.differentialprivacy.pipelinedp4j.core.budget.BudgetAccountingStrategy
3738
import com.google.privacy.differentialprivacy.pipelinedp4j.core.budget.BudgetAccountingStrategy.NAIVE
39+
import com.google.privacy.differentialprivacy.pipelinedp4j.core.budget.BudgetAllocationDetails
3840
import com.google.privacy.differentialprivacy.pipelinedp4j.core.budget.BudgetPerOpSpec
3941
import com.google.privacy.differentialprivacy.pipelinedp4j.core.budget.BudgetRequest
4042
import com.google.privacy.differentialprivacy.pipelinedp4j.core.budget.RelativeBudgetPerOpSpec
@@ -204,13 +206,26 @@ internal constructor(
204206
}
205207

206208
/**
207-
* Allocates privacy budgets to the metrics whose computation has been requested by calling
208-
* [aggregate]. This method must be called once per [DpEngine] instance.
209+
* Allocates privacy budgets to privacy-preserving operations in [aggregate] and
210+
* [selectPartitions] calls.
211+
*
212+
* Privacy-preserving operations are various aggregation metrics, like COUNT or SUM, and partition
213+
* selection. There might be multiple privacy-preserving operations in a single [DpEngine]
214+
* instance.
215+
*
216+
* This method must be called once per [DpEngine] instance.
217+
*
218+
* @return a list of [BudgetAllocationDetails] for each privacy-preserving operation. This reports
219+
* the actual budgets used during computation, which may include budgets for operations that
220+
* were not directly requested (e.g., for a MEAN aggregation, budget details for both SUM and
221+
* COUNT will be returned).
222+
* @throws IllegalStateException if [done] has already been called on this instance.
209223
*/
210-
fun done() {
224+
@CanIgnoreReturnValue
225+
fun done(): List<BudgetAllocationDetails> {
211226
throwIfDoneWasCalled()
212227
doneCalled = true
213-
budgetAccountant.allocateBudgets()
228+
return budgetAccountant.allocateBudgets()
214229
}
215230

216231
private fun throwIfDoneWasCalled() {

pipelinedp4j/main/com/google/privacy/differentialprivacy/pipelinedp4j/core/budget/BUILD.bazel

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,11 @@ package(
2525
],
2626
)
2727

28+
kt_jvm_library(
29+
name = "budget_allocation_details",
30+
srcs = ["BudgetAllocationDetails.kt"],
31+
)
32+
2833
kt_jvm_library(
2934
name = "budget_spec",
3035
srcs = ["BudgetSpec.kt"],
@@ -43,6 +48,8 @@ kt_jvm_library(
4348
srcs = ["BudgetAccountant.kt"],
4449
deps = [
4550
":allocated_budget",
51+
":budget_allocation_details",
4652
":budget_spec",
53+
"@maven//:com_google_errorprone_error_prone_annotations",
4754
],
4855
)

0 commit comments

Comments
 (0)