Skip to content

Commit d45c8f6

Browse files
refactor: make partner correlation spec-driven instead of hardcoded
- Add scope='partner_correlated' option to AttributeSpec
- Add correlation_rate field for explicit correlation probability
- Remove hardcoded PARTNER_CORRELATED_ATTRIBUTES list
- Update selector prompt to guide LLM on scope selection
- correlate_partner_attribute now takes attr_type and correlation_rate
- NPC partners always get correlated age regardless of scope
1 parent d42dfd8 commit d45c8f6

6 files changed

Lines changed: 173 additions & 113 deletions

File tree

extropy/core/models/population.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -400,9 +400,13 @@ class AttributeSpec(BaseModel):
400400
"universal", "population_specific", "context_specific", "personality"
401401
] = Field(description="Category of attribute")
402402
description: str = Field(description="What this attribute represents")
403-
scope: Literal["individual", "household"] = Field(
403+
scope: Literal["individual", "household", "partner_correlated"] = Field(
404404
default="individual",
405-
description="Whether this attribute is sampled per-individual or shared across a household",
405+
description="individual: varies per person; household: shared across household members; partner_correlated: correlated between partners using assortative mating",
406+
)
407+
correlation_rate: float | None = Field(
408+
default=None,
409+
description="For partner_correlated scope: probability (0-1) that partner has same value. None uses type-specific defaults (age uses gaussian, race uses per-group rates).",
406410
)
407411
sampling: SamplingConfig
408412
grounding: GroundingInfo
@@ -670,16 +674,17 @@ class DiscoveredAttribute(BaseModel):
670674
default="independent",
671675
description="independent: sample directly; derived: zero-variance formula; conditional: probabilistic dependency",
672676
)
673-
scope: Literal["individual", "household"] = Field(
677+
scope: Literal["individual", "household", "partner_correlated"] = Field(
674678
default="individual",
675-
description="individual: varies per person; household: shared across household members",
679+
description="individual: varies per person; household: shared across household members; partner_correlated: correlated between partners",
680+
)
681+
correlation_rate: float | None = Field(
682+
default=None,
683+
description="For partner_correlated scope: probability (0-1) that partner has same value",
676684
)
677685
depends_on: list[str] = Field(default_factory=list)
678686

679687

680-
# hydrated attribute seems to be an extension of discovered attribute.
681-
682-
683688
class HydratedAttribute(BaseModel):
684689
"""An attribute with distribution data from research (Step 2).
685690
@@ -696,9 +701,13 @@ class HydratedAttribute(BaseModel):
696701
strategy: Literal["independent", "derived", "conditional"] = Field(
697702
default="independent", description="Sampling strategy determined in Step 1"
698703
)
699-
scope: Literal["individual", "household"] = Field(
704+
scope: Literal["individual", "household", "partner_correlated"] = Field(
700705
default="individual",
701-
description="individual: varies per person; household: shared across household members",
706+
description="individual: varies per person; household: shared across household members; partner_correlated: correlated between partners",
707+
)
708+
correlation_rate: float | None = Field(
709+
default=None,
710+
description="For partner_correlated scope: probability (0-1) that partner has same value",
702711
)
703712
depends_on: list[str] = Field(default_factory=list)
704713
sampling: SamplingConfig

extropy/population/sampler/core.py

Lines changed: 47 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@
3333
correlate_partner_attribute,
3434
generate_dependents,
3535
estimate_household_count,
36-
PARTNER_CORRELATED_ATTRIBUTES,
3736
)
3837
from .modifiers import apply_modifiers_and_sample
3938
from ...utils.eval_safe import eval_formula, FormulaError
@@ -214,7 +213,7 @@ def _sample_population_independent(
214213

215214
def _generate_npc_partner(
216215
primary: dict[str, Any],
217-
household_attrs: set[str],
216+
attr_map: dict[str, AttributeSpec],
218217
categorical_options: dict[str, list[str]],
219218
rng: random.Random,
220219
config: HouseholdConfig,
@@ -223,37 +222,52 @@ def _generate_npc_partner(
223222
"""Generate a lightweight NPC partner profile for context.
224223
225224
Not a full agent — just enough for persona prompts and conversations.
225+
Uses attr.scope from the spec to determine which attributes to include.
226226
"""
227227
partner: dict[str, Any] = {}
228228

229-
if "age" in primary:
230-
partner["age"] = correlate_partner_attribute("age", primary["age"], rng, config)
229+
# Always include gender
231230
partner["gender"] = rng.choice(["male", "female"])
232231

233-
for attr in (
234-
"race_ethnicity",
235-
"education_level",
236-
"religious_affiliation",
237-
"political_orientation",
238-
):
239-
if attr in primary:
240-
correlated = correlate_partner_attribute(
241-
attr, primary[attr], rng, config, categorical_options.get(attr)
242-
)
243-
if correlated is not None:
244-
partner[attr] = correlated
232+
# Always correlate age if present (essential for NPC identity, regardless of scope)
233+
if "age" in primary:
234+
partner["age"] = correlate_partner_attribute(
235+
"age",
236+
"int",
237+
primary["age"],
238+
None, # Uses gaussian offset
239+
rng,
240+
config,
241+
)
245242

246-
# Shared household attrs
247-
for attr in household_attrs:
248-
if attr in primary:
249-
partner[attr] = primary[attr]
243+
# Process attributes based on their scope
244+
for attr_name, attr in attr_map.items():
245+
if attr_name not in primary or attr_name == "age":
246+
continue
247+
248+
if attr.scope == "household":
249+
# Shared: copy from primary
250+
partner[attr_name] = primary[attr_name]
251+
elif attr.scope == "partner_correlated":
252+
# Correlated: use assortative mating
253+
partner[attr_name] = correlate_partner_attribute(
254+
attr_name,
255+
attr.type,
256+
primary[attr_name],
257+
attr.correlation_rate,
258+
rng,
259+
config,
260+
available_options=categorical_options.get(attr_name),
261+
)
262+
# Individual scope: skip for NPC (not enough data to sample fully)
250263

251264
# Generate name for partner
252265
partner_age = partner.get("age")
253266
birth_decade = age_to_birth_decade(partner_age) if partner_age is not None else None
267+
ethnicity = partner.get("race_ethnicity") or partner.get("ethnicity") or partner.get("race")
254268
first_name, _ = generate_name(
255269
gender=partner["gender"],
256-
ethnicity=partner.get("race_ethnicity"),
270+
ethnicity=str(ethnicity) if ethnicity else None,
257271
birth_decade=birth_decade,
258272
seed=rng.randint(0, 2**31),
259273
name_config=name_config,
@@ -406,7 +420,7 @@ def _sample_population_households(
406420
# Partner is NPC context on the primary agent
407421
npc_partner = _generate_npc_partner(
408422
adult1,
409-
household_attrs,
423+
attr_map,
410424
categorical_options,
411425
rng,
412426
config,
@@ -502,10 +516,10 @@ def _sample_partner_agent(
502516
) -> dict[str, Any]:
503517
"""Sample a partner agent with correlated demographics.
504518
505-
- Household-scoped attributes are copied from the primary.
506-
- Correlated attributes (age, race, education, religion, politics)
507-
use assortative mating tables.
508-
- Everything else is sampled independently.
519+
Uses attr.scope from the spec to determine sampling behavior:
520+
- scope="household": copy from primary
521+
- scope="partner_correlated": use assortative mating correlation
522+
- scope="individual": sample independently
509523
"""
510524
if config is None:
511525
config = HouseholdConfig()
@@ -517,29 +531,21 @@ def _sample_partner_agent(
517531
continue
518532

519533
# Household-scoped: copy from primary
520-
if attr_name in household_attrs and attr_name in primary:
534+
if attr.scope == "household" and attr_name in primary:
521535
value = primary[attr_name]
522-
# Correlated: use partner correlation
523-
elif attr_name in PARTNER_CORRELATED_ATTRIBUTES and attr_name in primary:
524-
correlated = correlate_partner_attribute(
536+
# Partner-correlated: use assortative mating
537+
elif attr.scope == "partner_correlated" and attr_name in primary:
538+
value = correlate_partner_attribute(
525539
attr_name,
540+
attr.type,
526541
primary[attr_name],
542+
attr.correlation_rate,
527543
rng,
528544
config,
529545
available_options=categorical_options.get(attr_name),
530546
)
531-
if correlated is not None:
532-
value = correlated
533-
else:
534-
# Fallback: sample independently
535-
try:
536-
value = _sample_attribute(attr, rng, agent, stats)
537-
except FormulaError as e:
538-
raise SamplingError(
539-
f"Agent {index}: Failed to sample '{attr_name}': {e}"
540-
) from e
541547
else:
542-
# Independent sampling
548+
# Individual scope: sample independently
543549
try:
544550
value = _sample_attribute(attr, rng, agent, stats)
545551
except FormulaError as e:

extropy/population/sampler/households.py

Lines changed: 44 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,11 @@
33
All household composition rates and correlation tables are read from
44
a HouseholdConfig instance (populated by LLM research at spec time,
55
with US Census defaults as the safety net).
6+
7+
Attribute scopes are now spec-driven:
8+
- scope="household": shared across all household members
9+
- scope="partner_correlated": correlated between partners using assortative mating
10+
- scope="individual": sampled independently for each person
611
"""
712

813
from __future__ import annotations
@@ -15,32 +20,7 @@
1520
from ..names.generator import generate_name
1621

1722
if TYPE_CHECKING:
18-
from ...core.models.population import NameConfig
19-
20-
21-
# Attributes that are always shared within a household
22-
HOUSEHOLD_SHARED_ATTRIBUTES = [
23-
"state",
24-
"urban_rural",
25-
"household_income",
26-
"household_size",
27-
]
28-
29-
# Attributes correlated between partners (not copied, but biased)
30-
PARTNER_CORRELATED_ATTRIBUTES = [
31-
"age",
32-
"country",
33-
"race_ethnicity",
34-
"education_level",
35-
"religious_affiliation",
36-
"political_orientation",
37-
]
38-
39-
# Attributes sampled independently for each partner
40-
PARTNER_INDEPENDENT_ATTRIBUTES = [
41-
"personality",
42-
"occupation_sector",
43-
]
23+
from ...core.models.population import AttributeSpec, NameConfig
4424

4525

4626
def _age_bracket(age: int, config: HouseholdConfig) -> str:
@@ -92,20 +72,34 @@ def household_needs_kids(htype: HouseholdType) -> bool:
9272

9373
def correlate_partner_attribute(
9474
attr_name: str,
75+
attr_type: str,
9576
primary_value: Any,
77+
correlation_rate: float | None,
9678
rng: random.Random,
9779
config: HouseholdConfig,
9880
available_options: list[str] | None = None,
9981
) -> Any:
10082
"""Produce a correlated value for a partner based on the primary's value.
10183
102-
For categorical attributes, uses assortative mating rates to decide
103-
whether to copy or re-sample. For age, applies a Gaussian offset.
104-
105-
Returns the correlated value, or None if the attribute isn't in the
106-
correlation tables (caller should sample independently).
84+
Uses the correlation_rate from the attribute spec. Special handling:
85+
- age (int/float): Gaussian offset using config.partner_age_gap_mean/std
86+
- race_ethnicity-like attrs: Per-group rates from config.same_group_rates
87+
- Other categorical/boolean: Simple probability of same value
88+
89+
Args:
90+
attr_name: Name of the attribute
91+
attr_type: Type of the attribute (int, float, categorical, boolean)
92+
primary_value: The primary partner's value
93+
correlation_rate: Probability (0-1) that partner has same value, or None for defaults
94+
rng: Random number generator
95+
config: HouseholdConfig with default rates
96+
available_options: For categorical attrs, list of valid options to sample from
97+
98+
Returns:
99+
The correlated value for the partner.
107100
"""
108-
if attr_name == "age" and isinstance(primary_value, (int, float)):
101+
# Age uses gaussian offset, not simple correlation
102+
if attr_name == "age" and attr_type in ("int", "float"):
109103
partner_age = int(
110104
round(
111105
rng.gauss(
@@ -116,16 +110,8 @@ def correlate_partner_attribute(
116110
)
117111
return max(config.min_adult_age, partner_age)
118112

119-
if attr_name == "country":
120-
if rng.random() < config.same_country_rate:
121-
return primary_value
122-
if available_options:
123-
others = [o for o in available_options if o != primary_value]
124-
if others:
125-
return rng.choice(others)
126-
return primary_value
127-
128-
if attr_name == "race_ethnicity":
113+
# Race/ethnicity uses per-group rates from config
114+
if attr_name in ("race_ethnicity", "ethnicity", "race"):
129115
same_rate = config.same_group_rates.get(
130116
str(primary_value).lower(), config.default_same_group_rate
131117
)
@@ -137,17 +123,29 @@ def correlate_partner_attribute(
137123
return rng.choice(others)
138124
return primary_value
139125

140-
if attr_name in config.assortative_mating:
141-
correlation = config.assortative_mating[attr_name]
142-
if rng.random() < correlation:
126+
# Country uses same_country_rate from config if no explicit rate
127+
if attr_name == "country":
128+
rate = correlation_rate if correlation_rate is not None else config.same_country_rate
129+
if rng.random() < rate:
143130
return primary_value
144131
if available_options:
145132
others = [o for o in available_options if o != primary_value]
146133
if others:
147134
return rng.choice(others)
148135
return primary_value
149136

150-
return None # Not a correlated attribute
137+
# For all other attributes, use the explicit correlation_rate or a default
138+
rate = correlation_rate if correlation_rate is not None else config.default_same_group_rate
139+
if rng.random() < rate:
140+
return primary_value
141+
142+
# Sample a different value
143+
if available_options:
144+
others = [o for o in available_options if o != primary_value]
145+
if others:
146+
return rng.choice(others)
147+
148+
return primary_value
151149

152150

153151
def generate_dependents(

extropy/population/spec_builder/binder.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,7 @@ def bind_constraints(
140140
category=attr.category,
141141
description=attr.description,
142142
scope=attr.scope,
143+
correlation_rate=attr.correlation_rate,
143144
sampling=filtered_sampling,
144145
grounding=attr.grounding,
145146
constraints=attr.constraints,

0 commit comments

Comments
 (0)