Skip to content

Commit 4a8ac1a

Browse files
alexojicacopybara-github
authored andcommitted
Introduce FeatureSpec and add features to AggregationParams.
PiperOrigin-RevId: 816738058
1 parent c240793 commit 4a8ac1a

22 files changed

Lines changed: 1667 additions & 717 deletions

pipelinedp4j/main/com/google/privacy/differentialprivacy/pipelinedp4j/api/AggregationSpec.kt

Lines changed: 23 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -148,58 +148,20 @@ internal fun AggregationSpec.getFeatureId(): String {
148148
}
149149
}
150150

151-
internal fun List<AggregationSpec>.metrics(): List<MetricDefinition> = buildList {
152-
for (aggregation in this@metrics) {
153-
when (aggregation) {
154-
// Count and PrivacyIdCount do not aggregate any specific value, therefore they are handled
155-
// differently.
156-
is PrivacyIdCount ->
157-
add(
158-
MetricDefinition(
159-
MetricType.PRIVACY_ID_COUNT,
160-
aggregation.budget?.toInternalBudgetPerOpSpec(),
161-
)
162-
)
163-
is Count ->
164-
add(MetricDefinition(MetricType.COUNT, aggregation.budget?.toInternalBudgetPerOpSpec()))
165-
is ValueAggregations<*> -> {
166-
for (valueAggregationSpec in aggregation.valueAggregationSpecs) {
167-
add(
168-
MetricDefinition(
169-
valueAggregationSpec.metricType,
170-
valueAggregationSpec.budget?.toInternalBudgetPerOpSpec(),
171-
)
172-
)
173-
}
174-
}
175-
is VectorAggregations<*> -> {
176-
for (vectorAggregationSpec in aggregation.vectorAggregationSpecs) {
177-
add(
178-
MetricDefinition(
179-
vectorAggregationSpec.metricType,
180-
vectorAggregationSpec.budget?.toInternalBudgetPerOpSpec(),
181-
)
182-
)
183-
}
184-
}
185-
}
186-
}
187-
}
188-
189151
internal fun List<AggregationSpec>.outputColumnNamesWithMetricTypes():
190152
List<Pair<String, MetricType>> = buildList {
191153
for (aggregation in this@outputColumnNamesWithMetricTypes) {
192154
when (aggregation) {
193-
is PrivacyIdCount -> add(aggregation.outputColumnName to MetricType.PRIVACY_ID_COUNT)
194-
is Count -> add(aggregation.outputColumnName to MetricType.COUNT)
155+
is PrivacyIdCount -> add(Pair(aggregation.outputColumnName, MetricType.PRIVACY_ID_COUNT))
156+
is Count -> add(Pair(aggregation.outputColumnName, MetricType.COUNT))
195157
is ValueAggregations<*> -> {
196158
for (valueAggregationSpec in aggregation.valueAggregationSpecs) {
197-
add(valueAggregationSpec.outputColumnName to valueAggregationSpec.metricType)
159+
add(Pair(valueAggregationSpec.outputColumnName, valueAggregationSpec.metricType))
198160
}
199161
}
200162
is VectorAggregations<*> -> {
201163
for (vectorAggregationSpec in aggregation.vectorAggregationSpecs) {
202-
add(vectorAggregationSpec.outputColumnName to vectorAggregationSpec.metricType)
164+
add(Pair(vectorAggregationSpec.outputColumnName, vectorAggregationSpec.metricType))
203165
}
204166
}
205167
}
@@ -227,3 +189,22 @@ internal fun List<AggregationSpec>.outputColumnNameToFeatureIdMap(): Map<String,
227189

228190
internal fun List<AggregationSpec>.outputColumnNames(): List<String> =
229191
outputColumnNamesWithMetricTypes().map { it.first }
192+
193+
internal fun AggregationSpec.toNonFeatureMetricDefinition(): MetricDefinition {
194+
val (metricType, budget) =
195+
when (this) {
196+
is Count -> Pair(MetricType.COUNT, this.budget)
197+
is PrivacyIdCount -> Pair(MetricType.PRIVACY_ID_COUNT, this.budget)
198+
else ->
199+
throw IllegalArgumentException("Unsupported AggregationSpec type for non feature metrics")
200+
}
201+
return MetricDefinition(metricType, budget?.toInternalBudgetPerOpSpec())
202+
}
203+
204+
internal fun ValueAggregationSpec.toMetricDefinition(): MetricDefinition {
205+
return MetricDefinition(this.metricType, this.budget?.toInternalBudgetPerOpSpec())
206+
}
207+
208+
internal fun VectorAggregationSpec.toMetricDefinition(): MetricDefinition {
209+
return MetricDefinition(this.metricType, this.budget?.toInternalBudgetPerOpSpec())
210+
}

pipelinedp4j/main/com/google/privacy/differentialprivacy/pipelinedp4j/api/Query.kt

Lines changed: 46 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,14 @@ import com.google.privacy.differentialprivacy.pipelinedp4j.core.DpEngine
2323
import com.google.privacy.differentialprivacy.pipelinedp4j.core.DpEngineBudgetSpec
2424
import com.google.privacy.differentialprivacy.pipelinedp4j.core.Encoder
2525
import com.google.privacy.differentialprivacy.pipelinedp4j.core.EncoderFactory
26+
import com.google.privacy.differentialprivacy.pipelinedp4j.core.FeatureSpec
2627
import com.google.privacy.differentialprivacy.pipelinedp4j.core.FeatureValuesExtractor
2728
import com.google.privacy.differentialprivacy.pipelinedp4j.core.FrameworkCollection
2829
import com.google.privacy.differentialprivacy.pipelinedp4j.core.FrameworkTable
2930
import com.google.privacy.differentialprivacy.pipelinedp4j.core.MetricType
31+
import com.google.privacy.differentialprivacy.pipelinedp4j.core.ScalarFeatureSpec
3032
import com.google.privacy.differentialprivacy.pipelinedp4j.core.SelectPartitionsParams
33+
import com.google.privacy.differentialprivacy.pipelinedp4j.core.VectorFeatureSpec
3134
import com.google.privacy.differentialprivacy.pipelinedp4j.proto.DpAggregates
3235
import com.google.privacy.differentialprivacy.pipelinedp4j.proto.PerFeature
3336
import com.google.privacy.differentialprivacy.pipelinedp4j.proto.copy
@@ -494,22 +497,53 @@ protected constructor(
494497
valueAggregations: ValueAggregations<*>?,
495498
vectorAggregations: VectorAggregations<*>?,
496499
): AggregationParams {
497-
val valueContributionBounds = valueAggregations?.contributionBounds
498-
val vectorContributionBounds = vectorAggregations?.vectorContributionBounds
500+
val nonFeatureMetrics =
501+
aggregationSpecs
502+
.filter { it is Count || it is PrivacyIdCount }
503+
.map { it.toNonFeatureMetricDefinition() }
504+
val features =
505+
buildList<FeatureSpec> {
506+
if (valueAggregations != null) {
507+
val valueContributionBounds = valueAggregations.contributionBounds
508+
add(
509+
ScalarFeatureSpec(
510+
featureId = valueAggregations.getFeatureId(),
511+
metrics =
512+
valueAggregations.valueAggregationSpecs
513+
.map { it.toMetricDefinition() }
514+
.toImmutableList(),
515+
minValue = valueContributionBounds.valueBounds?.minValue,
516+
maxValue = valueContributionBounds.valueBounds?.maxValue,
517+
minTotalValue = valueContributionBounds.totalValueBounds?.minValue,
518+
maxTotalValue = valueContributionBounds.totalValueBounds?.maxValue,
519+
)
520+
)
521+
}
522+
if (vectorAggregations != null) {
523+
val vectorContributionBounds = vectorAggregations.vectorContributionBounds
524+
add(
525+
VectorFeatureSpec(
526+
featureId = vectorAggregations.getFeatureId(),
527+
metrics =
528+
vectorAggregations.vectorAggregationSpecs
529+
.map { it.toMetricDefinition() }
530+
.toImmutableList(),
531+
vectorSize = vectorAggregations.vectorSize,
532+
normKind = vectorContributionBounds.maxVectorTotalNorm.normKind.toInternalNormKind(),
533+
vectorMaxTotalNorm = vectorContributionBounds.maxVectorTotalNorm.value,
534+
)
535+
)
536+
}
537+
}
538+
499539
return AggregationParams(
500-
metrics = ImmutableList.copyOf(aggregationSpecs.metrics()),
540+
nonFeatureMetrics = nonFeatureMetrics.toImmutableList(),
541+
features = features.toImmutableList(),
501542
noiseKind =
502543
checkNotNull(noiseKind) { "noiseKind cannot be null if there are aggregations." }
503544
.toInternalNoiseKind(),
504545
maxPartitionsContributed = contributionBoundingLevel.getMaxPartitionsContributed(),
505546
maxContributionsPerPartition = contributionBoundingLevel.getMaxContributionsPerPartition(),
506-
minValue = valueContributionBounds?.valueBounds?.minValue,
507-
maxValue = valueContributionBounds?.valueBounds?.maxValue,
508-
minTotalValue = valueContributionBounds?.totalValueBounds?.minValue,
509-
maxTotalValue = valueContributionBounds?.totalValueBounds?.maxValue,
510-
vectorNormKind = vectorContributionBounds?.maxVectorTotalNorm?.normKind?.toInternalNormKind(),
511-
vectorMaxTotalNorm = vectorContributionBounds?.maxVectorTotalNorm?.value,
512-
vectorSize = vectorAggregations?.vectorSize,
513547
partitionSelectionBudget = groupsType.getBudget()?.toInternalBudgetPerOpSpec(),
514548
preThreshold = groupsType.getPreThreshold(),
515549
contributionBoundingLevel = contributionBoundingLevel.toInternalContributionBoundingLevel(),
@@ -534,3 +568,5 @@ protected constructor(
534568
}
535569
}
536570
}
571+
572+
private fun <T : Any> Iterable<T>.toImmutableList(): ImmutableList<T> = ImmutableList.copyOf(this)

0 commit comments

Comments
 (0)