|
| 1 | +"""Core unit assignment containers.""" |
| 2 | + |
| 3 | +from __future__ import annotations |
| 4 | + |
| 5 | +from collections.abc import Hashable, Iterable, Mapping |
| 6 | +from dataclasses import dataclass |
| 7 | + |
| 8 | +import pandas as pd |
| 9 | + |
| 10 | + |
| 11 | +def _series(values: pd.Series | Iterable[Hashable], name: str) -> pd.Series: |
| 12 | + if isinstance(values, pd.Series): |
| 13 | + return values.rename(name) |
| 14 | + return pd.Series(list(values), name=name) |
| 15 | + |
| 16 | + |
| 17 | +@dataclass(frozen=True) |
| 18 | +class UnitPartition: |
| 19 | + """A policy-unit assignment with exactly one unit per person.""" |
| 20 | + |
| 21 | + unit_type: str |
| 22 | + person_id: pd.Series |
| 23 | + unit_id: pd.Series |
| 24 | + role: pd.Series | None = None |
| 25 | + source: str | None = None |
| 26 | + |
| 27 | + def __post_init__(self) -> None: |
| 28 | + person_id = _series(self.person_id, "person_id") |
| 29 | + unit_id = _series(self.unit_id, "unit_id") |
| 30 | + |
| 31 | + if len(person_id) != len(unit_id): |
| 32 | + raise ValueError("person_id and unit_id must have the same length") |
| 33 | + if person_id.isna().any(): |
| 34 | + raise ValueError("person_id cannot contain missing values") |
| 35 | + if unit_id.isna().any(): |
| 36 | + raise ValueError("unit_id cannot contain missing values") |
| 37 | + if person_id.duplicated().any(): |
| 38 | + duplicates = person_id[person_id.duplicated()].unique().tolist() |
| 39 | + raise ValueError(f"person_id must be unique, found duplicates: {duplicates}") |
| 40 | + |
| 41 | + object.__setattr__(self, "person_id", person_id.reset_index(drop=True)) |
| 42 | + object.__setattr__(self, "unit_id", unit_id.reset_index(drop=True)) |
| 43 | + |
| 44 | + if self.role is not None: |
| 45 | + role = _series(self.role, "role") |
| 46 | + if len(role) != len(person_id): |
| 47 | + raise ValueError("role must have the same length as person_id") |
| 48 | + object.__setattr__(self, "role", role.reset_index(drop=True)) |
| 49 | + |
| 50 | + @classmethod |
| 51 | + def from_frame( |
| 52 | + cls, |
| 53 | + frame: pd.DataFrame, |
| 54 | + unit_type: str, |
| 55 | + person_col: str = "person_id", |
| 56 | + unit_col: str = "unit_id", |
| 57 | + role_col: str | None = None, |
| 58 | + source: str | None = None, |
| 59 | + ) -> UnitPartition: |
| 60 | + """Build a partition from columns in a person-level frame.""" |
| 61 | + |
| 62 | + role = frame[role_col] if role_col is not None else None |
| 63 | + return cls( |
| 64 | + unit_type=unit_type, |
| 65 | + person_id=frame[person_col], |
| 66 | + unit_id=frame[unit_col], |
| 67 | + role=role, |
| 68 | + source=source, |
| 69 | + ) |
| 70 | + |
| 71 | + @property |
| 72 | + def n_persons(self) -> int: |
| 73 | + return len(self.person_id) |
| 74 | + |
| 75 | + @property |
| 76 | + def n_units(self) -> int: |
| 77 | + return int(self.unit_id.nunique()) |
| 78 | + |
| 79 | + def to_frame(self) -> pd.DataFrame: |
| 80 | + """Return person-level unit assignments.""" |
| 81 | + |
| 82 | + frame = pd.DataFrame( |
| 83 | + { |
| 84 | + "person_id": self.person_id, |
| 85 | + "unit_id": self.unit_id, |
| 86 | + } |
| 87 | + ) |
| 88 | + if self.role is not None: |
| 89 | + frame["role"] = self.role |
| 90 | + return frame |
| 91 | + |
| 92 | + def members(self) -> dict[Hashable, tuple[Hashable, ...]]: |
| 93 | + """Return unit members keyed by unit ID.""" |
| 94 | + |
| 95 | + frame = self.to_frame() |
| 96 | + grouped = frame.groupby("unit_id", sort=False)["person_id"] |
| 97 | + return {unit_id: tuple(group.tolist()) for unit_id, group in grouped} |
| 98 | + |
| 99 | + def unit_sizes(self) -> pd.Series: |
| 100 | + """Return the number of people in each unit.""" |
| 101 | + |
| 102 | + return self.unit_id.value_counts(sort=False) |
| 103 | + |
| 104 | + def relabel(self, prefix: str = "unit_") -> UnitPartition: |
| 105 | + """Return a copy with dense, stable unit IDs in encounter order.""" |
| 106 | + |
| 107 | + codes = pd.factorize(self.unit_id, sort=False)[0] |
| 108 | + unit_id = pd.Series([f"{prefix}{code + 1}" for code in codes]) |
| 109 | + return UnitPartition( |
| 110 | + unit_type=self.unit_type, |
| 111 | + person_id=self.person_id, |
| 112 | + unit_id=unit_id, |
| 113 | + role=self.role, |
| 114 | + source=self.source, |
| 115 | + ) |
| 116 | + |
| 117 | + |
| 118 | +@dataclass(frozen=True) |
| 119 | +class EgoUnitMembership: |
| 120 | + """A possibly-overlapping unit assignment for each focal person.""" |
| 121 | + |
| 122 | + unit_type: str |
| 123 | + focal_person_id: pd.Series |
| 124 | + member_person_id: pd.Series |
| 125 | + role: pd.Series | None = None |
| 126 | + source: str | None = None |
| 127 | + |
| 128 | + def __post_init__(self) -> None: |
| 129 | + focal = _series(self.focal_person_id, "focal_person_id") |
| 130 | + member = _series(self.member_person_id, "member_person_id") |
| 131 | + |
| 132 | + if len(focal) != len(member): |
| 133 | + raise ValueError("focal_person_id and member_person_id must align") |
| 134 | + if focal.isna().any() or member.isna().any(): |
| 135 | + raise ValueError("ego unit memberships cannot contain missing IDs") |
| 136 | + |
| 137 | + pairs = pd.DataFrame({"focal": focal, "member": member}) |
| 138 | + if pairs.duplicated().any(): |
| 139 | + raise ValueError("ego unit memberships cannot contain duplicate pairs") |
| 140 | + |
| 141 | + object.__setattr__(self, "focal_person_id", focal.reset_index(drop=True)) |
| 142 | + object.__setattr__(self, "member_person_id", member.reset_index(drop=True)) |
| 143 | + |
| 144 | + if self.role is not None: |
| 145 | + role = _series(self.role, "role") |
| 146 | + if len(role) != len(focal): |
| 147 | + raise ValueError("role must have the same length as memberships") |
| 148 | + object.__setattr__(self, "role", role.reset_index(drop=True)) |
| 149 | + |
| 150 | + @classmethod |
| 151 | + def from_mapping( |
| 152 | + cls, |
| 153 | + unit_type: str, |
| 154 | + memberships: Mapping[Hashable, Iterable[Hashable]], |
| 155 | + source: str | None = None, |
| 156 | + ) -> EgoUnitMembership: |
| 157 | + """Build overlapping units from focal-person membership sets.""" |
| 158 | + |
| 159 | + focal_ids: list[Hashable] = [] |
| 160 | + member_ids: list[Hashable] = [] |
| 161 | + for focal, members in memberships.items(): |
| 162 | + for member in members: |
| 163 | + focal_ids.append(focal) |
| 164 | + member_ids.append(member) |
| 165 | + return cls(unit_type, pd.Series(focal_ids), pd.Series(member_ids), source=source) |
| 166 | + |
| 167 | + def to_frame(self) -> pd.DataFrame: |
| 168 | + """Return membership rows keyed by focal person and member person.""" |
| 169 | + |
| 170 | + frame = pd.DataFrame( |
| 171 | + { |
| 172 | + "focal_person_id": self.focal_person_id, |
| 173 | + "member_person_id": self.member_person_id, |
| 174 | + } |
| 175 | + ) |
| 176 | + if self.role is not None: |
| 177 | + frame["role"] = self.role |
| 178 | + return frame |
| 179 | + |
| 180 | + def members_for(self, focal_person_id: Hashable) -> tuple[Hashable, ...]: |
| 181 | + frame = self.to_frame() |
| 182 | + members = frame.loc[ |
| 183 | + frame["focal_person_id"] == focal_person_id, "member_person_id" |
| 184 | + ] |
| 185 | + return tuple(members.tolist()) |
0 commit comments