Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 2 additions & 7 deletions src/partialstats/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,7 @@
partialstats — distributed statistical aggregation via partial results.
"""

from .partials import Partial, SumPartial, SumOfSquaresPartial
from .partials import MeanPartial, VariancePartial
from .combiners import Combiner

__all__ = [
"Partial",
"Combiner",
"SumPartial",
"SumOfSquaresPartial",
]
__all__ = ["Combiner", "MeanPartial", "VariancePartial"]
11 changes: 9 additions & 2 deletions src/partialstats/combiners/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,11 @@
from .core import Combiner
from .core import Combiner, CombinerProtocol, SumCombiner
from .statistical import mean_combiner, variance_combiner, std_combiner

__all__ = ["Combiner", "mean_combiner", "variance_combiner", "std_combiner"]
__all__ = [
"Combiner",
"CombinerProtocol",
"SumCombiner",
"mean_combiner",
"variance_combiner",
"std_combiner",
]
42 changes: 35 additions & 7 deletions src/partialstats/combiners/core.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from typing import TypeVar, Generic, Iterable, Callable
from typing import TypeVar, Generic, Iterable, Callable, Protocol, final
from dataclasses import dataclass
from functools import reduce

from ..partials.protocol import S
from partialstats.partials.protocol import AddsProtocol

S = TypeVar("S")
S_contra = TypeVar("S_contra", contravariant=True)
R = TypeVar("R")
R_co = TypeVar("R_co", covariant=True)


def build_combine_function(
Expand All @@ -22,14 +25,19 @@ def build_combine_function(
A function for calculating the final desired result
"""

def combine(partials: Iterable[S]):
def combine(partials: Iterable[S]) -> R:
first, *rest = partials
return finalise(reduce(aggregate, rest, first))

return combine


@dataclass
class CombinerProtocol(Protocol[S_contra, R_co]):
def combine(self, partials: Iterable[S_contra]) -> R_co: ...


@final
@dataclass(frozen=True)
class Combiner(Generic[S, R]):
"""
Runs on the aggregator to combine partial results from all nodes into
Expand All @@ -41,11 +49,31 @@ class Combiner(Generic[S, R]):
"""

finalise: Callable[[S], R]
aggregate: Callable[[S, S], S] = lambda a, b: a + b
aggregate: Callable[[S, S], S]

def combine(self, partials: Iterable[S]) -> R:
"""
Folds the partial results together using `aggregate`, then calls `finalise`.
"""
first, *rest = partials
return self.finalise(reduce(self.aggregate, rest, first))
return build_combine_function(self.aggregate, self.finalise)(partials)


Adds = TypeVar("Adds", bound=AddsProtocol)


@final
@dataclass(frozen=True)
class SumCombiner(Generic[Adds, R]):
"""
Runs on the aggregator to combine partial results from all nodes into
the final statistic.

Type parameters:
Adds: the type of the partial results produced by a PartialReducer. Must implement the __add__ special method.
R: the type of the final result
"""

finalise: Callable[[Adds], R]

def combine(self, partials: Iterable[Adds]) -> R:
return build_combine_function(lambda a, b: a + b, self.finalise)(partials)
1 change: 1 addition & 0 deletions src/partialstats/combiners/scalar.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from .core import Combiner

sum_combiner = Combiner[float, float](
aggregate=lambda a, b: a + b,
finalise=lambda x: x,
)

Expand Down
33 changes: 19 additions & 14 deletions src/partialstats/combiners/statistical.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,34 @@
from math import sqrt

from ..partials import SumPartial, SumOfSquaresPartial
from .core import Combiner
from ..partials.protocol import (
CountPartialProtocol,
SumPartialProtocol,
MeanPartialProtocol,
VariancePartialProtocol,
)
from .core import SumCombiner

count_combiner = Combiner[SumPartial, float](
count_combiner = SumCombiner[CountPartialProtocol, int](
finalise=lambda x: x.count,
)
"""Combines SumPartials into a global count"""
"""Combines count partial results into a global count"""

sum_combiner = Combiner[SumPartial, float](
sum_combiner = SumCombiner[SumPartialProtocol, float](
finalise=lambda x: x.sum,
)
"""Combines SumPartials into a global sum"""
"""Combines sum partial results into a global sum"""

mean_combiner = Combiner[SumPartial, float](
mean_combiner = SumCombiner[MeanPartialProtocol, float](
finalise=lambda x: x.sum / x.count,
)
"""Combines SumPartials into a global mean."""
"""Combines partial results into a global mean."""

variance_combiner = Combiner[SumOfSquaresPartial, float](
finalise=lambda x: x.sumsq / x.count - (x.sum / x.count) ** 2,
variance_combiner = SumCombiner[VariancePartialProtocol, float](
finalise=lambda x: x.sum_of_squares / x.count - (x.sum / x.count) ** 2,
)
"""Combines SumOfSquaresPartials into a global population variance."""
"""Combines partial results into a global population variance."""

std_combiner = Combiner[SumOfSquaresPartial, float](
finalise=lambda x: sqrt(x.sumsq / x.count - (x.sum / x.count) ** 2),
std_combiner = SumCombiner[VariancePartialProtocol, float](
finalise=lambda x: sqrt(x.sum_of_squares / x.count - (x.sum / x.count) ** 2),
)
"""Combines SumOfSquaresPartials into a global population standard deviation."""
"""Combines partial results into a global population standard deviation."""
7 changes: 3 additions & 4 deletions src/partialstats/partials/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from .protocol import Partial
from .sum import SumPartial
from .sum_of_squares import SumOfSquaresPartial
from .mean import MeanPartial
from .variance import VariancePartial

__all__ = ["Partial", "SumPartial", "SumOfSquaresPartial"]
__all__ = ["MeanPartial", "VariancePartial"]
Original file line number Diff line number Diff line change
@@ -1,18 +1,14 @@
from typing import Self
from dataclasses import dataclass
from .protocol import Partial
from .protocol import AddsProtocol, MeanPartialProtocol


@dataclass
class SumPartial(Partial):
class MeanPartial(MeanPartialProtocol, AddsProtocol):
"""Partial result carrying a running sum and count."""

sum: float
count: int

def __add__(self, other: Self) -> Self:
return type(self)(self.sum + other.sum, self.count + other.count)

@classmethod
def identity(cls) -> Self:
return cls(0, 0)
26 changes: 22 additions & 4 deletions src/partialstats/partials/protocol.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,25 @@
from typing import Protocol, TypeVar
from typing import Protocol, Self

S = TypeVar("S")

class AddsProtocol(Protocol):
def __add__(self, other: Self) -> Self: ...

class Partial(Protocol):
def __add__(self: S, other: S) -> S: ...

class CountPartialProtocol(AddsProtocol, Protocol):
count: int


class SumPartialProtocol(AddsProtocol, Protocol):
sum: float


class MeanPartialProtocol(CountPartialProtocol, SumPartialProtocol, Protocol): ...


class SumOfSquaresPartialProtocol(AddsProtocol, Protocol):
sum_of_squares: float


class VariancePartialProtocol(
MeanPartialProtocol, SumOfSquaresPartialProtocol, Protocol
): ...
21 changes: 0 additions & 21 deletions src/partialstats/partials/sum_of_squares.py

This file was deleted.

19 changes: 19 additions & 0 deletions src/partialstats/partials/variance.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from typing import Self
from dataclasses import dataclass
from .protocol import AddsProtocol, VariancePartialProtocol


@dataclass
class VariancePartial(VariancePartialProtocol, AddsProtocol):
"""Partial result carrying a running sum, sum of squares, and count."""

sum: float
sum_of_squares: float
count: int

def __add__(self, other: Self) -> Self:
return type(self)(
self.sum + other.sum,
self.sum_of_squares + other.sum_of_squares,
self.count + other.count,
)
21 changes: 12 additions & 9 deletions src/partialstats/reference/reducers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@
from dataclasses import dataclass
from functools import reduce

from ..partials import SumPartial, SumOfSquaresPartial
from ..partials.protocol import S
from ..combiners import Combiner
from ..partials import MeanPartial, VariancePartial
from ..combiners import CombinerProtocol

T = TypeVar("T")
S = TypeVar("S")
R = TypeVar("R")


Expand All @@ -21,7 +21,7 @@ class PartialReducer(Generic[T, S]):
"""

apply: Callable[[T], S]
merge: Callable[[S, S], S] = lambda a, b: a + b
merge: Callable[[S, S], S]

def reduce(self, rows: Iterable[T]) -> S:
"""
Expand All @@ -34,17 +34,20 @@ def reduce(self, rows: Iterable[T]) -> S:
# Reducer implementations for reference

count_reducer = PartialReducer[object, int](
merge=lambda a, b: a + b,
apply=lambda _: 1,
)
"""Counts the number of rows in each partition."""

sum_reducer = PartialReducer[float, SumPartial](
apply=lambda x: SumPartial(sum=x, count=1),
sum_reducer = PartialReducer[float, MeanPartial](
merge=lambda a, b: a + b,
apply=lambda x: MeanPartial(sum=x, count=1),
)
"""Accumulates the sum and count of values — sufficient to compute mean."""

sum_of_squares_reducer = PartialReducer[float, SumOfSquaresPartial](
apply=lambda x: SumOfSquaresPartial(sum=x, sumsq=x * x, count=1),
sum_of_squares_reducer = PartialReducer[float, VariancePartial](
merge=lambda a, b: a + b,
apply=lambda x: VariancePartial(sum=x, sum_of_squares=x * x, count=1),
)
"""Accumulates sum, sum of squares, and count — sufficient to compute variance and std dev."""

Expand Down Expand Up @@ -75,7 +78,7 @@ class DistributedStat(Generic[T, S, R]):
"""

reducer: PartialReducer[T, S]
combiner: Combiner[S, R]
combiner: CombinerProtocol[S, R]

def compute(self, partitions: Iterable[Iterable[T]]) -> R:
"""
Expand Down
8 changes: 8 additions & 0 deletions src/partialstats/stat_aggregators/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from .stat_aggregators import (
CountAggregator,
SumAggregator,
MeanAggregator,
VarianceAggregator,
)

__all__ = ["CountAggregator", "SumAggregator", "MeanAggregator", "VarianceAggregator"]
Loading