Skip to content

Commit d42dfd8

Browse files
feat: add country correlation for multi-country household sampling
- Add country to PARTNER_CORRELATED_ATTRIBUTES (95% same-country rate)
- Add same_country_rate config field to HouseholdConfig
- Auto-inject country attribute for multi-country geographies (world, continents, regions)
- Detect patterns like 'east asia', 'europe', 'global', 'apac', etc.
1 parent 22bccf8 commit d42dfd8

4 files changed

Lines changed: 109 additions & 0 deletions

File tree

extropy/core/models/population.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,8 @@ class HouseholdConfig(BaseModel):
106106
}
107107
)
108108
default_same_group_rate: float = 0.85
109+
# Partner correlation: same-country rate for international populations
110+
same_country_rate: float = 0.95
109111
assortative_mating: dict[str, float] = Field(
110112
default_factory=lambda: {
111113
"education_level": 0.6,

extropy/population/sampler/households.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
# Attributes correlated between partners (not copied, but biased)
3030
PARTNER_CORRELATED_ATTRIBUTES = [
3131
"age",
32+
"country",
3233
"race_ethnicity",
3334
"education_level",
3435
"religious_affiliation",
@@ -115,6 +116,15 @@ def correlate_partner_attribute(
115116
)
116117
return max(config.min_adult_age, partner_age)
117118

119+
if attr_name == "country":
120+
if rng.random() < config.same_country_rate:
121+
return primary_value
122+
if available_options:
123+
others = [o for o in available_options if o != primary_value]
124+
if others:
125+
return rng.choice(others)
126+
return primary_value
127+
118128
if attr_name == "race_ethnicity":
119129
same_rate = config.same_group_rates.get(
120130
str(primary_value).lower(), config.default_same_group_rate

extropy/population/spec_builder/selector.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,65 @@
1111
from ...core.models import AttributeSpec, DiscoveredAttribute
1212

1313

14+
# Multi-country geography patterns that should trigger country attribute injection
15+
MULTI_COUNTRY_PATTERNS = [
16+
# Global
17+
"world",
18+
"global",
19+
"globe",
20+
"international",
21+
"worldwide",
22+
# Continents
23+
"africa",
24+
"asia",
25+
"europe",
26+
"north america",
27+
"south america",
28+
"latin america",
29+
"oceania",
30+
"antarctica",
31+
# Regions spanning multiple countries
32+
"east asia",
33+
"southeast asia",
34+
"south asia",
35+
"central asia",
36+
"middle east",
37+
"western europe",
38+
"eastern europe",
39+
"central europe",
40+
"northern europe",
41+
"southern europe",
42+
"sub-saharan africa",
43+
"north africa",
44+
"central america",
45+
"caribbean",
46+
"pacific",
47+
"nordic",
48+
"scandinavian",
49+
"balkan",
50+
"mediterranean",
51+
"gulf",
52+
"apac",
53+
"emea",
54+
"latam",
55+
"amer",
56+
]
57+
58+
59+
def _is_multi_country_geography(geography: str | None, description: str) -> bool:
60+
"""Detect if the population spans multiple countries."""
61+
if not geography and not description:
62+
return False
63+
64+
text = f"{geography or ''} {description}".lower()
65+
66+
for pattern in MULTI_COUNTRY_PATTERNS:
67+
if pattern in text:
68+
return True
69+
70+
return False
71+
72+
1473
# JSON schema for attribute selection response
1574
ATTRIBUTE_SELECTION_SCHEMA = {
1675
"type": "object",
@@ -334,4 +393,22 @@ def trait_already_exists(trait: str, existing: set[str]) -> bool:
334393
)
335394
)
336395

396+
# Inject country attribute for multi-country geographies
397+
if _is_multi_country_geography(geography, description):
398+
existing_names = {a.name for a in attributes}
399+
if "country" not in existing_names:
400+
# Insert at the beginning (universal demographic)
401+
attributes.insert(
402+
0,
403+
DiscoveredAttribute(
404+
name="country",
405+
type="categorical",
406+
category="universal",
407+
description="Country of residence",
408+
strategy="independent",
409+
scope="household",
410+
depends_on=[],
411+
),
412+
)
413+
337414
return attributes

tests/test_household_sampling.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -422,3 +422,23 @@ def test_partner_age_correlation(self):
422422
avg_diff = sum(age_diffs) / len(age_diffs)
423423
# Average age gap should be small (< 10 years)
424424
assert avg_diff < 10, f"Average age gap {avg_diff:.1f} too large"
425+
426+
def test_country_correlation(self):
    """Partners should share the primary's country about 95% of the time."""
    rng = random.Random(42)
    config = HouseholdConfig(same_country_rate=0.95)
    countries = ["USA", "India", "UK", "Japan", "Brazil"]

    total = 1000
    same_country = 0
    for _ in range(total):
        # Draw the primary first, then the partner, from the same seeded RNG.
        primary_country = rng.choice(countries)
        partner_country = correlate_partner_attribute(
            "country", primary_country, rng, config, countries
        )
        same_country += int(partner_country == primary_country)

    rate = same_country / total
    # Allow a statistical margin around the configured 0.95 rate.
    assert 0.90 < rate < 0.99, f"Same-country rate {rate:.2%} out of expected range"

0 commit comments

Comments
 (0)