diff --git a/extropy/cli/commands/network.py b/extropy/cli/commands/network.py index c81a187..a564baf 100644 --- a/extropy/cli/commands/network.py +++ b/extropy/cli/commands/network.py @@ -324,8 +324,8 @@ def network_command( "auto_save_generated_config": auto_save_generated_config, "quarantine_suffix": quarantine_suffix, } - config = ( - config.model_copy(update=base_updates).apply_quality_profile_defaults(force=True) + config = config.model_copy(update=base_updates).apply_quality_profile_defaults( + force=True ) advanced_updates = {} @@ -541,7 +541,9 @@ def do_generation(): quality_meta = result.meta.get("quality", {}) accepted = bool(quality_meta.get("accepted", True)) - strict_failed = config.topology_gate == "strict" and not accepted and len(agents) >= 50 + strict_failed = ( + config.topology_gate == "strict" and not accepted and len(agents) >= 50 + ) # Save canonical output to study DB (or quarantine on strict failure) console.print() @@ -583,15 +585,15 @@ def do_generation(): console.print( "[yellow]![/yellow] Topology gate strict failed. Saved quarantined artifact; canonical network not overwritten." 
) - console.print( - f"[yellow]![/yellow] Quarantined network_id={target_network_id}" - ) + console.print(f"[yellow]![/yellow] Quarantined network_id={target_network_id}") console.print( f"[red]✗[/red] Failed gates with best metrics: {quality_meta.get('best_metrics', {})}" ) if gate_deltas: console.print(f"[dim]Gate deltas: {gate_deltas}[/dim]") - console.print(f"[dim]inspect via: extropy inspect network-status --study-db {study_db} --network-run-id {network_run_id}[/dim]") + console.print( + f"[dim]inspect via: extropy inspect network-status --study-db {study_db} --network-run-id {network_run_id}[/dim]" + ) raise typer.Exit(1) if strict_failed and not config.allow_quarantine: console.print( diff --git a/extropy/cli/commands/sample.py b/extropy/cli/commands/sample.py index bbea94c..5969c37 100644 --- a/extropy/cli/commands/sample.py +++ b/extropy/cli/commands/sample.py @@ -285,6 +285,33 @@ def on_progress(current: int, total: int): ) out.blank() + # Household report (if applicable) + households = getattr(result, "_households", []) + if households and report and not get_json_mode(): + out.header("HOUSEHOLD REPORT") + type_counts: dict[str, int] = {} + for hh in households: + ht = hh["household_type"] + type_counts[ht] = type_counts.get(ht, 0) + 1 + hh_rows = [[htype, str(cnt)] for htype, cnt in sorted(type_counts.items())] + out.table( + "Household Types", + ["Type", "Count"], + hh_rows, + styles=["cyan", None], + ) + out.text( + f" Total households: {len(households)}, Total agents: {len(result.agents)}" + ) + out.blank() + + if households and get_json_mode(): + out.set_data("household_count", len(households)) + out.set_data( + "household_type_distribution", + result.meta.get("household_type_distribution", {}), + ) + # Save to canonical DB out.blank() if not get_json_mode(): @@ -301,6 +328,12 @@ def on_progress(current: int, total: int): meta=result.meta, seed=result.meta.get("seed"), ) + if households: + db.save_households( + population_id=population_id, + 
sample_run_id=sample_run_id, + households=households, + ) else: with open_study_db(study_db) as db: db.save_population_spec( @@ -314,6 +347,12 @@ def on_progress(current: int, total: int): meta=result.meta, seed=result.meta.get("seed"), ) + if households: + db.save_households( + population_id=population_id, + sample_run_id=sample_run_id, + households=households, + ) elapsed = time.time() - start_time diff --git a/extropy/core/models/__init__.py b/extropy/core/models/__init__.py index 642090b..063816c 100644 --- a/extropy/core/models/__init__.py +++ b/extropy/core/models/__init__.py @@ -34,6 +34,10 @@ PopulationSpec, # Pipeline types SufficiencyResult, + # Household models + HouseholdType, + Dependent, + STANDARD_PERSONALITY_ATTRIBUTES, ) # Validation models (shared across population and scenario) @@ -139,6 +143,10 @@ "PopulationSpec", # Population - Pipeline types "SufficiencyResult", + # Population - Household models + "HouseholdType", + "Dependent", + "STANDARD_PERSONALITY_ATTRIBUTES", # Scenario - Event "EventType", "Event", diff --git a/extropy/core/models/network.py b/extropy/core/models/network.py index 0542f29..9574015 100644 --- a/extropy/core/models/network.py +++ b/extropy/core/models/network.py @@ -28,10 +28,12 @@ class Edge(BaseModel): edge_type: str bidirectional: bool = True influence_weight: dict[str, float] = Field(default_factory=dict) + structural: bool = False + context: str | None = None # "household", "workplace", "neighborhood", etc. 
def to_dict(self) -> dict[str, Any]: """Convert to dictionary for JSON serialization.""" - return { + d: dict[str, Any] = { "source": self.source, "target": self.target, "weight": round(self.weight, 4), @@ -46,6 +48,11 @@ def to_dict(self) -> dict[str, Any]: ), }, } + if self.structural: + d["structural"] = True + if self.context: + d["context"] = self.context + return d class NodeMetrics(BaseModel): diff --git a/extropy/core/models/population.py b/extropy/core/models/population.py index 6715dae..ba618ac 100644 --- a/extropy/core/models/population.py +++ b/extropy/core/models/population.py @@ -13,6 +13,7 @@ from collections import defaultdict from datetime import datetime +from enum import Enum from pathlib import Path from typing import Literal @@ -20,6 +21,42 @@ from pydantic import BaseModel, Field +# ============================================================================= +# Household Models +# ============================================================================= + + +class HouseholdType(str, Enum): + SINGLE = "single" + COUPLE = "couple" + SINGLE_PARENT = "single_parent" + COUPLE_WITH_KIDS = "couple_with_kids" + MULTI_GENERATIONAL = "multi_generational" + + +class Dependent(BaseModel): + """NPC dependent (child, elderly parent).""" + + name: str + age: int + gender: str + relationship: str # "son", "daughter", "mother", etc. + school_status: str | None = None # "home", "elementary", "middle_school", etc. + + +# Standard personality attributes that spec builders should include. +# `conformity` (float, 0-1, correlated with agreeableness) is consumed by +# Phase C for threshold behavior in simulation. 
+STANDARD_PERSONALITY_ATTRIBUTES = [ + "neuroticism", + "extraversion", + "openness", + "conscientiousness", + "agreeableness", + "conformity", +] + + # ============================================================================= # Grounding Information # ============================================================================= @@ -247,6 +284,10 @@ class AttributeSpec(BaseModel): "universal", "population_specific", "context_specific", "personality" ] = Field(description="Category of attribute") description: str = Field(description="What this attribute represents") + scope: Literal["individual", "household"] = Field( + default="individual", + description="Whether this attribute is sampled per-individual or shared across a household", + ) sampling: SamplingConfig grounding: GroundingInfo constraints: list[Constraint] = Field(default_factory=list) diff --git a/extropy/population/network/config.py b/extropy/population/network/config.py index 36288e2..bb09827 100644 --- a/extropy/population/network/config.py +++ b/extropy/population/network/config.py @@ -161,6 +161,13 @@ class NetworkConfig(BaseModel): target_clustering_tolerance: float = 0.08 bridge_budget_fraction: float = 0.08 swap_passes: int = 3 + degree_distribution_target: Literal["uniform", "power_law"] | None = None + power_law_exponent: float = 2.5 # only used when target is power_law + identity_clustering_attributes: list[str] = Field( + default_factory=list, + description="Attributes for in-group edge density boost (e.g., political_orientation, religious_affiliation)", + ) + identity_clustering_boost: float = 1.5 # multiplier on intra-group edge probability auto_save_generated_config: bool = True allow_quarantine: bool = True quarantine_suffix: str = "rejected" diff --git a/extropy/population/network/generator.py b/extropy/population/network/generator.py index e064681..77cbd33 100644 --- a/extropy/population/network/generator.py +++ b/extropy/population/network/generator.py @@ -1035,7 +1035,9 @@ def 
_build_adjacency_from_edges( return adjacency -def _acceptance_bounds(config: NetworkConfig) -> tuple[float, float, float, float, float]: +def _acceptance_bounds( + config: NetworkConfig, +) -> tuple[float, float, float, float, float]: degree_delta = max(1.0, config.avg_degree * config.target_degree_tolerance_pct) deg_min = max(1.0, config.avg_degree - degree_delta) deg_max = config.avg_degree + degree_delta @@ -1222,7 +1224,7 @@ def _ensure_intra_similarity_coverage( members_by_community[c].append(idx) intra_counts = [0] * n - for (i, j) in similarities.keys(): + for i, j in similarities.keys(): if communities[i] == communities[j]: intra_counts[i] += 1 intra_counts[j] += 1 @@ -1626,11 +1628,16 @@ def _apply_rewiring( edge_set: set[tuple[str, str]], config: NetworkConfig, rng: random.Random, + protected_pairs: set[tuple[str, str]] | None = None, ) -> tuple[list[Edge], set[tuple[str, str]], int]: """Apply Watts-Strogatz rewiring for small-world properties. + Structural edges in protected_pairs are never rewired. + Returns (edges, edge_set, rewired_count). """ + if protected_pairs is None: + protected_pairs = set() n = len(agents) n_rewire = int(len(edges) * config.rewire_prob) rewired_count = 0 @@ -1643,6 +1650,10 @@ def _apply_rewiring( edge_idx = rng.randint(0, len(edges) - 1) old_edge = edges[edge_idx] + # Skip structural/protected edges + if (old_edge.source, old_edge.target) in protected_pairs: + continue + source_idx = id_to_idx.get(old_edge.source) if source_idx is None: continue @@ -1737,6 +1748,174 @@ def _compute_modularity_fast( return q +def _generate_structural_edges( + agents: list[dict[str, Any]], + agent_ids: list[str], + rng: random.Random, +) -> list[Edge]: + """Generate deterministic structural edges from agent attributes. + + Creates edges based on shared household, partner status, workplace, + neighborhood, congregation, and school-parent connections. 
+ """ + id_to_idx = {aid: i for i, aid in enumerate(agent_ids)} + edges: list[Edge] = [] + added: set[tuple[int, int]] = set() + + def _add(i: int, j: int, etype: str, weight: float, context: str) -> None: + pair = (min(i, j), max(i, j)) + if pair in added or i == j: + return + added.add(pair) + edges.append( + Edge( + source=agent_ids[i], + target=agent_ids[j], + weight=weight, + edge_type=etype, + structural=True, + context=context, + ) + ) + + # Build indexes for batch matching + household_map: dict[str, list[int]] = {} + sector_state_map: dict[tuple[str, str], list[int]] = {} + state_urban_map: dict[tuple[str, str], list[int]] = {} + religion_state_map: dict[tuple[str, str], list[int]] = {} + school_parent_map: dict[tuple[str, str], list[int]] = {} + + for idx, agent in enumerate(agents): + hh_id = agent.get("household_id") + if hh_id: + household_map.setdefault(hh_id, []).append(idx) + + sector = agent.get("occupation_sector") + state = agent.get("state") + urban = agent.get("urban_rural") + + if sector and state: + sector_state_map.setdefault((sector, state), []).append(idx) + if state and urban: + state_urban_map.setdefault((state, urban), []).append(idx) + + religion = agent.get("religious_affiliation") + if ( + religion + and state + and religion.lower() not in ("none", "atheist", "agnostic") + ): + religion_state_map.setdefault((religion, state), []).append(idx) + + # School parents: agents with school-age dependents + dependents = agent.get("dependents", []) + has_school_kid = any( + isinstance(d, dict) + and d.get("school_status") in ("elementary", "middle_school", "high_school") + for d in dependents + ) + if has_school_kid and state and urban: + school_parent_map.setdefault((state, urban), []).append(idx) + + # 1. 
Partner edges (weight 1.0, max 1 per agent) + for idx, agent in enumerate(agents): + partner_id = agent.get("partner_id") + if partner_id and partner_id in id_to_idx: + j = id_to_idx[partner_id] + _add(idx, j, "partner", 1.0, "household") + + # 2. Household edges (weight 0.9, all members in same household) + for members in household_map.values(): + for i_pos in range(len(members)): + for j_pos in range(i_pos + 1, len(members)): + _add(members[i_pos], members[j_pos], "household", 0.9, "household") + + # 3. Coworker edges (weight 0.6, capped at ~8 per agent) + _MAX_COWORKER = 8 + coworker_count: dict[int, int] = {} + for pool in sector_state_map.values(): + if len(pool) < 2: + continue + shuffled = list(pool) + rng.shuffle(shuffled) + for i_pos, i in enumerate(shuffled): + if coworker_count.get(i, 0) >= _MAX_COWORKER: + continue + for j in shuffled[i_pos + 1 :]: + if coworker_count.get(j, 0) >= _MAX_COWORKER: + continue + if coworker_count.get(i, 0) >= _MAX_COWORKER: + break + _add(i, j, "coworker", 0.6, "workplace") + coworker_count[i] = coworker_count.get(i, 0) + 1 + coworker_count[j] = coworker_count.get(j, 0) + 1 + + # 4. Neighbor edges (weight 0.4, capped at ~4, age within 15yr) + _MAX_NEIGHBOR = 4 + neighbor_count: dict[int, int] = {} + for pool in state_urban_map.values(): + if len(pool) < 2: + continue + shuffled = list(pool) + rng.shuffle(shuffled) + for i_pos, i in enumerate(shuffled): + if neighbor_count.get(i, 0) >= _MAX_NEIGHBOR: + continue + age_i = agents[i].get("age", 40) + for j in shuffled[i_pos + 1 :]: + if neighbor_count.get(j, 0) >= _MAX_NEIGHBOR: + continue + if neighbor_count.get(i, 0) >= _MAX_NEIGHBOR: + break + age_j = agents[j].get("age", 40) + if abs(age_i - age_j) <= 15: + _add(i, j, "neighbor", 0.4, "neighborhood") + neighbor_count[i] = neighbor_count.get(i, 0) + 1 + neighbor_count[j] = neighbor_count.get(j, 0) + 1 + + # 5. 
Congregation edges (weight 0.4, capped at ~4) + _MAX_CONGREGATION = 4 + congregation_count: dict[int, int] = {} + for pool in religion_state_map.values(): + if len(pool) < 2: + continue + shuffled = list(pool) + rng.shuffle(shuffled) + for i_pos, i in enumerate(shuffled): + if congregation_count.get(i, 0) >= _MAX_CONGREGATION: + continue + for j in shuffled[i_pos + 1 :]: + if congregation_count.get(j, 0) >= _MAX_CONGREGATION: + continue + if congregation_count.get(i, 0) >= _MAX_CONGREGATION: + break + _add(i, j, "congregation", 0.4, "congregation") + congregation_count[i] = congregation_count.get(i, 0) + 1 + congregation_count[j] = congregation_count.get(j, 0) + 1 + + # 6. School parent edges (weight 0.35, capped at ~3) + _MAX_SCHOOL_PARENT = 3 + school_count: dict[int, int] = {} + for pool in school_parent_map.values(): + if len(pool) < 2: + continue + shuffled = list(pool) + rng.shuffle(shuffled) + for i_pos, i in enumerate(shuffled): + if school_count.get(i, 0) >= _MAX_SCHOOL_PARENT: + continue + for j in shuffled[i_pos + 1 :]: + if school_count.get(j, 0) >= _MAX_SCHOOL_PARENT: + continue + if school_count.get(i, 0) >= _MAX_SCHOOL_PARENT: + break + _add(i, j, "school_parent", 0.35, "school") + school_count[i] = school_count.get(i, 0) + 1 + school_count[j] = school_count.get(j, 0) + 1 + + return edges + + def _generate_network_single_pass( agents: list[dict], agent_ids: list[str], @@ -1862,7 +2041,9 @@ def emit_progress( return if study_db_file is not None and network_run_id: if isinstance(message, dict): - message_text = json.dumps(message, sort_keys=True, separators=(",", ":")) + message_text = json.dumps( + message, sort_keys=True, separators=(",", ":") + ) else: message_text = message or stage with open_study_db(study_db_file) as db: @@ -1876,7 +2057,9 @@ def emit_progress( def _stage_plan() -> list[dict[str, Any]]: if config.candidate_mode == "exact": - return [{"name": "exact", "candidate_mode": "exact", "pool_multiplier": 0.0}] + return [ + {"name": 
"exact", "candidate_mode": "exact", "pool_multiplier": 0.0} + ] base_mult = max(6.0, config.candidate_pool_multiplier) if config.quality_profile == "fast": return [ @@ -1937,7 +2120,10 @@ def _stage_plan() -> list[dict[str, Any]]: stage_summaries: list[dict[str, Any]] = [] calibration_step = 0 calibration_total = max( - 1, len(stage_plan) * config.calibration_restarts * config.max_calibration_iterations + 1, + len(stage_plan) + * config.calibration_restarts + * config.max_calibration_iterations, ) for stage_idx, stage in enumerate(stage_plan): @@ -1975,7 +2161,9 @@ def _stage_plan() -> list[dict[str, Any]]: if candidate_map is None: candidate_mode = "exact" elif stage.get("hybrid_expand"): - expanded_pool = int(max(config.avg_degree * 20, stage_cfg.min_candidate_pool)) + expanded_pool = int( + max(config.avg_degree * 20, stage_cfg.min_candidate_pool) + ) candidate_map = _expand_candidate_map_undercovered( candidate_map=candidate_map, n=n, @@ -2016,14 +2204,15 @@ def _stage_plan() -> list[dict[str, Any]]: if should_resume_similarity: if checkpoint_file is None or not checkpoint_file.exists(): raise ValueError(f"Checkpoint not found: {checkpoint_file}") - similarities, start_row, completed_chunk_starts = _load_similarity_checkpoint( - checkpoint_file, checkpoint_signature + similarities, start_row, completed_chunk_starts = ( + _load_similarity_checkpoint(checkpoint_file, checkpoint_signature) ) if checkpoint_job_id and checkpoint_file is not None: with open_study_db(checkpoint_file) as db: db.mark_similarity_job_running(checkpoint_job_id) use_parallel_similarity = stage_cfg.similarity_workers > 1 + def similarity_progress(_stage: str, current: int, total: int) -> None: emit_progress( "Computing similarities", @@ -2082,6 +2271,31 @@ def similarity_progress(_stage: str, current: int, total: int) -> None: final_blocking_attrs = blocking_attrs if candidate_mode == "blocked" else [] final_similarity_pairs = len(similarities) + # Apply identity clustering boost if 
configured + if config.identity_clustering_attributes and similarities: + id_attrs = config.identity_clustering_attributes + boost = config.identity_clustering_boost + boosted_count = 0 + for (i, j), sim in list(similarities.items()): + shared = 0 + for attr in id_attrs: + va = agents[i].get(attr) + vb = agents[j].get(attr) + if va is not None and vb is not None and va == vb: + shared += 1 + if shared > 0: + # Apply multiplicative boost, capped at 1.0 + new_sim = min(1.0, sim * (boost**shared)) + if new_sim != sim: + similarities[(i, j)] = new_sim + boosted_count += 1 + if boosted_count > 0: + logger.debug( + "Identity clustering: boosted %d pairs on attributes %s", + boosted_count, + id_attrs, + ) + if needs_escalation and stage_idx < len(stage_plan) - 1: stage_summaries.append( { @@ -2127,7 +2341,10 @@ def similarity_progress(_stage: str, current: int, total: int) -> None: force_db=True, ) - if community_diag.get("low_signal", 0.0) >= 1.0 and stage_idx < len(stage_plan) - 1: + if ( + community_diag.get("low_signal", 0.0) >= 1.0 + and stage_idx < len(stage_plan) - 1 + ): stage_summaries.append( { "stage": stage_name, @@ -2157,7 +2374,8 @@ def similarity_progress(_stage: str, current: int, total: int) -> None: with open_study_db(study_db_file) as db: calibration_run_id = db.create_network_calibration_run( network_run_id=network_run_id, - restart_index=(stage_idx * config.calibration_restarts) + restart, + restart_index=(stage_idx * config.calibration_restarts) + + restart, seed=restart_seed, ) @@ -2200,11 +2418,15 @@ def similarity_progress(_stage: str, current: int, total: int) -> None: ) if should_bridge_swap: if modularity > (mod_max + 0.02): - swap_budget = max(4, int(len(edges) * config.bridge_budget_fraction)) + swap_budget = max( + 4, int(len(edges) * config.bridge_budget_fraction) + ) else: # Connectivity repair should be conservative so we do not # collapse modular structure while bridging components. 
- ratio = min(config.bridge_budget_fraction * 0.25, lcc_deficit * 0.20) + ratio = min( + config.bridge_budget_fraction * 0.25, lcc_deficit * 0.20 + ) swap_budget = max(1, int(len(edges) * ratio)) edges = _apply_bridge_swaps( edges, @@ -2376,7 +2598,9 @@ def similarity_progress(_stage: str, current: int, total: int) -> None: if stage_accepted: break - emit_progress("Calibrating network", calibration_total, calibration_total, force_db=True) + emit_progress( + "Calibrating network", calibration_total, calibration_total, force_db=True + ) edges = best_edges or [] # Rebuild edge_set from best edges @@ -2385,10 +2609,50 @@ def similarity_progress(_stage: str, current: int, total: int) -> None: edge_set.add((edge.source, edge.target)) edge_set.add((edge.target, edge.source)) - # Step 5: Watts-Strogatz rewiring + # Step 4b: Generate structural edges and merge (protected from pruning/rewiring) + structural_edges = _generate_structural_edges(agents, agent_ids, rng) + structural_pairs: set[tuple[str, str]] = set() + structural_added = 0 + for se in structural_edges: + if (se.source, se.target) not in edge_set: + edges.append(se) + edge_set.add((se.source, se.target)) + edge_set.add((se.target, se.source)) + structural_added += 1 + else: + # Mark existing edge as structural if it matches + for existing in edges: + if (existing.source == se.source and existing.target == se.target) or ( + existing.source == se.target and existing.target == se.source + ): + existing.structural = True + existing.context = se.context + break + structural_pairs.add((se.source, se.target)) + structural_pairs.add((se.target, se.source)) + + if structural_added > 0: + emit_progress( + "Structural edges", + structural_added, + structural_added, + message={ + "structural_added": structural_added, + "total_structural": len(structural_edges), + }, + force_db=True, + ) + + # Step 5: Watts-Strogatz rewiring (skip structural edges) emit_progress("Rewiring edges", 0, len(edges), force_db=True) edges, edge_set, 
rewired_count = _apply_rewiring( - agents, agent_ids, edges, edge_set, config, rng + agents, + agent_ids, + edges, + edge_set, + config, + rng, + protected_pairs=structural_pairs, ) emit_progress("Rewiring edges", len(edges), len(edges), force_db=True) @@ -2439,6 +2703,8 @@ def similarity_progress(_stage: str, current: int, total: int) -> None: }, "ladder_stages": stage_summaries, }, + "structural_edge_count": len(structural_edges), + "structural_edges_added": structural_added, "resume_calibration_requested": resume_calibration, "generated_at": datetime.now().isoformat(), } diff --git a/extropy/population/sampler/core.py b/extropy/population/sampler/core.py index fed63a9..b1909c0 100644 --- a/extropy/population/sampler/core.py +++ b/extropy/population/sampler/core.py @@ -2,6 +2,10 @@ The sampler is a generic spec interpreter - it doesn't know about surgeons or farmers, it just executes whatever spec it's given. + +Supports two modes: +- Independent sampling (legacy): each agent sampled independently +- Household sampling: agents are grouped into households with correlated demographics """ import json @@ -20,6 +24,15 @@ ) from ...utils.callbacks import ItemProgressCallback from .distributions import sample_distribution, coerce_to_type +from .households import ( + sample_household_type, + household_needs_partner, + household_needs_kids, + correlate_partner_attribute, + generate_dependents, + estimate_household_count, + PARTNER_CORRELATED_ATTRIBUTES, +) from .modifiers import apply_modifiers_and_sample from ...utils.eval_safe import eval_formula, FormulaError @@ -32,6 +45,11 @@ class SamplingError(Exception): pass +def _has_household_attributes(spec: PopulationSpec) -> bool: + """Check if the spec has household-scoped attributes, indicating household mode.""" + return any(attr.scope == "household" for attr in spec.attributes) + + def sample_population( spec: PopulationSpec, count: int | None = None, @@ -41,6 +59,10 @@ def sample_population( """ Generate agents from a 
PopulationSpec. + If the spec contains household-scoped attributes, agents are sampled in + household units with correlated demographics between partners. Otherwise, + agents are sampled independently (legacy behavior). + Args: spec: The population specification to sample from count: Number of agents to generate (defaults to spec.meta.size) @@ -89,32 +111,262 @@ def sample_population( attr.name: [] for attr in spec.attributes if attr.type in ("int", "float") } - agents: list[dict[str, Any]] = [] + use_households = _has_household_attributes(spec) - for i in range(n): - agent = _sample_single_agent( - spec, attr_map, rng, i, id_width, stats, numeric_values + if use_households: + agents, households = _sample_population_households( + spec, attr_map, rng, n, id_width, stats, numeric_values, on_progress ) - agents.append(agent) - - if on_progress: - on_progress(i + 1, n) + else: + agents = _sample_population_independent( + spec, attr_map, rng, n, id_width, stats, numeric_values, on_progress + ) + households = [] # Compute final statistics - _finalize_stats(stats, numeric_values, n) + _finalize_stats(stats, numeric_values, len(agents)) # Check expression constraints _check_expression_constraints(spec, agents, stats) # Build metadata - meta = { + meta: dict[str, Any] = { "spec": spec.meta.description, - "count": n, + "count": len(agents), "seed": seed, "generated_at": datetime.now().isoformat(), } + if households: + meta["household_count"] = len(households) + meta["household_mode"] = True + # Household type distribution + type_counts: dict[str, int] = {} + for hh in households: + ht = hh["household_type"] + type_counts[ht] = type_counts.get(ht, 0) + 1 + meta["household_type_distribution"] = type_counts + + result = SamplingResult(agents=agents, meta=meta, stats=stats) + # Attach households for DB persistence (not part of SamplingResult model, + # but accessible as an ad-hoc attribute for save_sample_result) + result._households = households # type: ignore[attr-defined] + 
return result + + +def _sample_population_independent( + spec: PopulationSpec, + attr_map: dict[str, AttributeSpec], + rng: random.Random, + n: int, + id_width: int, + stats: SamplingStats, + numeric_values: dict[str, list[float]], + on_progress: ItemProgressCallback | None = None, +) -> list[dict[str, Any]]: + """Sample N agents independently (legacy path).""" + agents: list[dict[str, Any]] = [] + for i in range(n): + agent = _sample_single_agent( + spec, attr_map, rng, i, id_width, stats, numeric_values + ) + agents.append(agent) + if on_progress: + on_progress(i + 1, n) + return agents + + +def _sample_population_households( + spec: PopulationSpec, + attr_map: dict[str, AttributeSpec], + rng: random.Random, + target_n: int, + id_width: int, + stats: SamplingStats, + numeric_values: dict[str, list[float]], + on_progress: ItemProgressCallback | None = None, +) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]: + """Sample agents in household units with correlated demographics. + + Returns (agents, households) where households is a list of household + metadata dicts for DB persistence. 
+ """ + num_households = estimate_household_count(target_n) + hh_id_width = len(str(num_households - 1)) + + agents: list[dict[str, Any]] = [] + households: list[dict[str, Any]] = [] + agent_index = 0 + + # Identify household-scoped attributes and collect categorical options + household_attrs = { + attr.name for attr in spec.attributes if attr.scope == "household" + } + categorical_options: dict[str, list[str]] = {} + for attr in spec.attributes: + if attr.type == "categorical" and attr.sampling.distribution: + dist = attr.sampling.distribution + if hasattr(dist, "options"): + categorical_options[attr.name] = dist.options + + for hh_idx in range(num_households): + if agent_index >= target_n: + break + + household_id = f"household_{hh_idx:0{hh_id_width}d}" + + # Sample Adult 1 (primary) + adult1 = _sample_single_agent( + spec, attr_map, rng, agent_index, id_width, stats, numeric_values + ) + adult1_age = adult1.get("age", 35) + agent_index += 1 + + # Determine household type + htype = sample_household_type(adult1_age, rng) + + has_partner = household_needs_partner(htype) + has_kids = household_needs_kids(htype) + num_adults = 2 if has_partner else 1 + + # Determine household_size from agent if present, else estimate + household_size = adult1.get( + "household_size", num_adults + (1 if has_kids else 0) + ) + if isinstance(household_size, (int, float)): + household_size = max(num_adults, int(household_size)) + else: + household_size = num_adults + (1 if has_kids else 0) + + # Annotate Adult 1 with household fields + adult1["household_id"] = household_id + adult1["household_role"] = "adult_primary" + + adult_ids = [adult1["_id"]] + + if has_partner and agent_index < target_n: + # Sample Adult 2 with correlated demographics + adult2 = _sample_partner_agent( + spec, + attr_map, + rng, + agent_index, + id_width, + stats, + numeric_values, + adult1, + household_attrs, + categorical_options, + ) + adult2["household_id"] = household_id + adult2["household_role"] = 
"adult_secondary" + # Partners share a surname + if adult1.get("last_name"): + adult2["last_name"] = adult1["last_name"] + adult2["partner_id"] = adult1["_id"] + adult1["partner_id"] = adult2["_id"] + adult_ids.append(adult2["_id"]) + agent_index += 1 + else: + adult1["partner_id"] = None + + # Generate NPC dependents + dependents = generate_dependents( + htype, household_size, num_adults, adult1_age, rng + ) + dep_dicts = [d.model_dump() for d in dependents] + + # Attach dependents to all adults + adult1["dependents"] = dep_dicts + agents.append(adult1) + + if has_partner and len(adult_ids) > 1: + adult2["dependents"] = dep_dicts + agents.append(adult2) + + # Build household record + shared_attrs = {} + for attr_name in household_attrs: + if attr_name in adult1: + shared_attrs[attr_name] = adult1[attr_name] + + households.append( + { + "id": household_id, + "household_type": htype.value, + "adult_ids": adult_ids, + "dependent_data": dep_dicts, + "shared_attributes": shared_attrs, + } + ) + + if on_progress: + on_progress(min(agent_index, target_n), target_n) + + return agents, households + + +def _sample_partner_agent( + spec: PopulationSpec, + attr_map: dict[str, AttributeSpec], + rng: random.Random, + index: int, + id_width: int, + stats: SamplingStats, + numeric_values: dict[str, list[float]], + primary: dict[str, Any], + household_attrs: set[str], + categorical_options: dict[str, list[str]], +) -> dict[str, Any]: + """Sample a partner agent with correlated demographics. - return SamplingResult(agents=agents, meta=meta, stats=stats) + - Household-scoped attributes are copied from the primary. + - Correlated attributes (age, race, education, religion, politics) + use assortative mating tables. + - Everything else is sampled independently. 
+ """ + agent: dict[str, Any] = {"_id": f"agent_{index:0{id_width}d}"} + + for attr_name in spec.sampling_order: + attr = attr_map.get(attr_name) + if attr is None: + continue + + # Household-scoped: copy from primary + if attr_name in household_attrs and attr_name in primary: + value = primary[attr_name] + # Correlated: use partner correlation + elif attr_name in PARTNER_CORRELATED_ATTRIBUTES and attr_name in primary: + correlated = correlate_partner_attribute( + attr_name, + primary[attr_name], + rng, + available_options=categorical_options.get(attr_name), + ) + if correlated is not None: + value = correlated + else: + # Fallback: sample independently + try: + value = _sample_attribute(attr, rng, agent, stats) + except FormulaError as e: + raise SamplingError( + f"Agent {index}: Failed to sample '{attr_name}': {e}" + ) from e + else: + # Independent sampling + try: + value = _sample_attribute(attr, rng, agent, stats) + except FormulaError as e: + raise SamplingError( + f"Agent {index}: Failed to sample '{attr_name}': {e}" + ) from e + + value = coerce_to_type(value, attr.type) + value = _apply_hard_constraints(value, attr) + agent[attr_name] = value + _update_stats(attr, value, stats, numeric_values) + + return agent def _sample_single_agent( diff --git a/extropy/population/sampler/households.py b/extropy/population/sampler/households.py new file mode 100644 index 0000000..ec6df45 --- /dev/null +++ b/extropy/population/sampler/households.py @@ -0,0 +1,270 @@ +"""Household-based sampling for co-sampling correlated agent pairs. + +Contains Census-derived household composition rates and correlation tables +for assortative mating on demographics. 
# =============================================================================
# Census-derived household composition rates by age bracket of primary adult
# =============================================================================

# Household-type probabilities keyed by the age bracket of the primary adult
# (Adult 1). The bracket keys are the strings produced by _age_bracket(), and
# the five weights inside every bracket sum to 1.0.
HOUSEHOLD_TYPE_WEIGHTS: dict[str, dict[HouseholdType, float]] = {
    "18-29": {
        HouseholdType.SINGLE: 0.45,
        HouseholdType.COUPLE: 0.25,
        HouseholdType.SINGLE_PARENT: 0.08,
        HouseholdType.COUPLE_WITH_KIDS: 0.15,
        HouseholdType.MULTI_GENERATIONAL: 0.07,
    },
    "30-44": {
        HouseholdType.SINGLE: 0.20,
        HouseholdType.COUPLE: 0.15,
        HouseholdType.SINGLE_PARENT: 0.12,
        HouseholdType.COUPLE_WITH_KIDS: 0.40,
        HouseholdType.MULTI_GENERATIONAL: 0.13,
    },
    "45-64": {
        HouseholdType.SINGLE: 0.25,
        HouseholdType.COUPLE: 0.35,
        HouseholdType.SINGLE_PARENT: 0.08,
        HouseholdType.COUPLE_WITH_KIDS: 0.20,
        HouseholdType.MULTI_GENERATIONAL: 0.12,
    },
    "65+": {
        HouseholdType.SINGLE: 0.35,
        HouseholdType.COUPLE: 0.40,
        HouseholdType.SINGLE_PARENT: 0.02,
        HouseholdType.COUPLE_WITH_KIDS: 0.05,
        HouseholdType.MULTI_GENERATIONAL: 0.18,
    },
}

# Same-race partnering rates. Despite the constant's name, each value is the
# probability that the partner has the SAME race_ethnicity as the primary
# adult, not the probability of intermarriage. Keys are lowercase group names
# (lookups lowercase the primary's value first).
INTERMARRIAGE_RATES: dict[str, float] = {
    "white": 0.90,
    "black": 0.82,
    "hispanic": 0.78,
    "asian": 0.75,
    "other": 0.50,
}
# Rate used when the primary's group has no entry in INTERMARRIAGE_RATES.
_DEFAULT_SAME_RACE_RATE = 0.85

# Assortative mating strength per attribute; higher means the partner is more
# likely to share the primary's value. correlate_partner_attribute() compares
# the value against rng.random(), i.e. it acts as the probability of copying
# the primary's value rather than as a statistical correlation coefficient.
ASSORTATIVE_MATING: dict[str, float] = {
    "education_level": 0.6,
    "religious_affiliation": 0.7,
    "political_orientation": 0.5,
}
# Partner age gap as (mean_offset, std), where the offset is
# Adult2_age - Adult1_age. Only a "default" entry is defined.
AGE_GAP_PARAMS: dict[str, tuple[float, float]] = {
    "default": (-2.0, 3.0),
}

# Average household size used for estimating number of households from N.
AVG_HOUSEHOLD_SIZE = 2.5


def _age_bracket(age: int) -> str:
    """Return the HOUSEHOLD_TYPE_WEIGHTS key whose bracket contains ``age``.

    Ages below 30 (including any below 18) map to ``"18-29"``; ages of 65 and
    above map to ``"65+"``.
    """
    for exclusive_upper, label in ((30, "18-29"), (45, "30-44"), (65, "45-64")):
        if age < exclusive_upper:
            return label
    return "65+"


def sample_household_type(primary_age: int, rng: random.Random) -> HouseholdType:
    """Draw one household type using the weights of the primary's age bracket."""
    bracket_weights = HOUSEHOLD_TYPE_WEIGHTS[_age_bracket(primary_age)]
    (picked,) = rng.choices(
        list(bracket_weights.keys()),
        weights=list(bracket_weights.values()),
        k=1,
    )
    return picked


def household_needs_partner(htype: HouseholdType) -> bool:
    """Tell whether ``htype`` is a household with a second adult partner."""
    return (
        htype == HouseholdType.COUPLE
        or htype == HouseholdType.COUPLE_WITH_KIDS
        or htype == HouseholdType.MULTI_GENERATIONAL
    )


def household_needs_kids(htype: HouseholdType) -> bool:
    """Tell whether ``htype`` is a household that includes children."""
    return (
        htype == HouseholdType.SINGLE_PARENT
        or htype == HouseholdType.COUPLE_WITH_KIDS
        or htype == HouseholdType.MULTI_GENERATIONAL
    )


# Attributes that are always shared within a household
HOUSEHOLD_SHARED_ATTRIBUTES = [
    "state",
    "urban_rural",
    "household_income",
    "household_size",
]

# Attributes correlated between partners (not copied, but biased)
PARTNER_CORRELATED_ATTRIBUTES = [
    "age",
    "race_ethnicity",
    "education_level",
    "religious_affiliation",
    "political_orientation",
]

# Attributes sampled independently for each partner
PARTNER_INDEPENDENT_ATTRIBUTES = [
    "personality",
    "occupation_sector",
]
def correlate_partner_attribute(
    attr_name: str,
    primary_value: Any,
    rng: random.Random,
    available_options: list[str] | None = None,
) -> Any:
    """Derive a partner's value for ``attr_name`` from the primary's value.

    * ``age`` (numeric primary value): Gaussian draw centred on the primary's
      age plus the default mean offset, rounded to an int and floored at 18.
    * ``race_ethnicity``: keep the primary's value with the group's same-race
      rate, otherwise pick a different entry of ``available_options``.
    * attributes in ``ASSORTATIVE_MATING``: same copy-or-switch rule, using the
      attribute's table entry as the copy probability.

    When a switch is called for but no alternative option exists, the
    primary's value is returned.  ``None`` is returned for any attribute
    without a rule, signalling the caller to sample independently.
    """

    def copy_or_switch(copy_probability: float) -> Any:
        # One rng.random() draw decides copy vs. switch; a switch costs one
        # additional rng.choice() draw when alternatives are available.
        if rng.random() < copy_probability:
            return primary_value
        if available_options:
            alternatives = [
                option for option in available_options if option != primary_value
            ]
            if alternatives:
                return rng.choice(alternatives)
        return primary_value

    if attr_name == "age" and isinstance(primary_value, (int, float)):
        mean_offset, std = AGE_GAP_PARAMS.get("default", (-2.0, 3.0))
        drawn_age = rng.gauss(primary_value + mean_offset, std)
        return max(18, int(round(drawn_age)))

    if attr_name == "race_ethnicity":
        group_key = str(primary_value).lower()
        return copy_or_switch(
            INTERMARRIAGE_RATES.get(group_key, _DEFAULT_SAME_RACE_RATE)
        )

    if attr_name in ASSORTATIVE_MATING:
        return copy_or_switch(ASSORTATIVE_MATING[attr_name])

    return None  # Not a correlated attribute
# Oldest age at which a generated child still counts as a minor dependent.
_MAX_CHILD_AGE = 17


def generate_dependents(
    household_type: HouseholdType,
    household_size: int,
    num_adults: int,
    primary_age: int,
    rng: random.Random,
) -> list[Dependent]:
    """Generate NPC dependents for a household.

    Fills the gap between ``num_adults`` and ``household_size``.  A
    multi-generational household with at least one free slot gets exactly one
    elderly parent (22-35 years older than the primary adult); every remaining
    slot becomes a child aged 0-17.

    Args:
        household_type: Type of the household being populated.
        household_size: Total number of people in the household.
        num_adults: Number of sampled adult agents already in the household.
        primary_age: Age of the primary adult; anchors dependent ages.
        rng: Random source; draws happen in a fixed order for reproducibility.

    Returns:
        The dependents, elderly parent first (if any), then children.  Empty
        when ``household_size`` does not exceed ``num_adults``.
    """
    num_dependents = max(0, household_size - num_adults)
    if num_dependents == 0:
        return []

    dependents: list[Dependent] = []

    # Multi-generational households include one elderly parent.
    elderly_count = 0
    if household_type == HouseholdType.MULTI_GENERATIONAL:
        elderly_count = 1
        elderly_age = primary_age + rng.randint(22, 35)
        elderly_gender = rng.choice(["male", "female"])
        relationship = "father" if elderly_gender == "male" else "mother"
        dependents.append(
            Dependent(
                name=f"Dependent ({relationship})",
                age=elderly_age,
                gender=elderly_gender,
                relationship=relationship,
                school_status=None,
            )
        )

    # Remaining dependents are children.
    for _ in range(num_dependents - elderly_count):
        child_age = _sample_child_age(primary_age, rng)
        child_gender = rng.choice(["male", "female"])
        relationship = "son" if child_gender == "male" else "daughter"
        dependents.append(
            Dependent(
                name=f"Dependent ({relationship})",
                age=child_age,
                gender=child_gender,
                relationship=relationship,
                school_status=_school_status(child_age),
            )
        )

    return dependents


def _sample_child_age(parent_age: int, rng: random.Random) -> int:
    """Sample a child age in ``[0, 17]`` that is plausible for ``parent_age``.

    The parent is assumed to have had the child between ages 20 and 40, which
    gives the window ``[parent_age - 40, parent_age - 20]``.  That window is
    intersected with ``[0, 17]``.  For parents of 58 and older the whole window
    lies above 17, so the lower bound is clamped to 17 as well; previously this
    case called ``rng.randint`` with an empty range and raised ``ValueError``.

    Parents aged 20 or younger yield 0 without consuming a random draw; every
    other input consumes exactly one ``rng.randint`` draw.
    """
    oldest_possible = max(0, parent_age - 20)
    youngest_possible = max(0, parent_age - 40)
    if oldest_possible <= youngest_possible:
        # Only reachable when both bounds collapsed to 0 (very young parent).
        return max(0, min(_MAX_CHILD_AGE, oldest_possible))

    upper = min(_MAX_CHILD_AGE, oldest_possible)
    # Clamp so the range is never empty for older parents.
    lower = min(youngest_possible, upper)
    return rng.randint(lower, upper)


def _school_status(age: int) -> str:
    """Map an age to a school-status label (``"adult"`` from 18 upwards)."""
    if age < 5:
        return "home"
    if age < 11:
        return "elementary"
    if age < 14:
        return "middle_school"
    if age < 18:
        return "high_school"
    return "adult"


def estimate_household_count(target_agents: int) -> int:
    """Estimate number of households needed to produce target_agents individuals.

    Divides by ``AVG_HOUSEHOLD_SIZE``, rounds up, and never returns less than 1.
    """
    return max(1, math.ceil(target_agents / AVG_HOUSEHOLD_SIZE))
def save_households(
    self,
    population_id: str,
    sample_run_id: str,
    households: list[dict[str, Any]],
) -> None:
    """Replace the stored household records of ``population_id``.

    Existing rows for the population are deleted and ``households`` is
    inserted in their place, then the transaction is committed.  An empty
    ``households`` list is a no-op (existing rows are left untouched).

    Args:
        population_id: Population the households belong to.
        sample_run_id: Sampling run that produced the households.
        households: Records with a required ``"id"`` key and optional
            ``"household_type"``, ``"adult_ids"``, ``"dependent_data"`` and
            ``"shared_attributes"`` keys.

    Raises:
        KeyError: if a record has no ``"id"``; nothing is modified.
        Exception: any database error is re-raised after a rollback, so the
            previously stored households stay intact.
    """
    if not households:
        return

    # Serialize everything BEFORE touching the table: a malformed record must
    # not leave a pending DELETE on the shared connection, where the next
    # unrelated commit would silently persist it.
    rows = [
        (
            hh["id"],
            population_id,
            sample_run_id,
            hh.get("household_type"),
            _dumps(hh.get("adult_ids", [])),
            _dumps(hh.get("dependent_data", [])),
            _dumps(hh.get("shared_attributes", {})),
        )
        for hh in households
    ]

    cursor = self.conn.cursor()
    try:
        cursor.execute(
            "DELETE FROM households WHERE population_id = ?", (population_id,)
        )
        cursor.executemany(
            """
            INSERT INTO households
            (id, population_id, sample_run_id, household_type, adult_ids, dependent_data, shared_attributes)
            VALUES (?, ?, ?, ?, ?, ?, ?)
            """,
            rows,
        )
    except Exception:
        # Undo the DELETE (and any partial inserts) before propagating.
        self.conn.rollback()
        raise
    self.conn.commit()
def get_households(self, population_id: str) -> list[dict[str, Any]]:
    """Return the household records of ``population_id``, ordered by id.

    The ``adult_ids``, ``dependent_data`` and ``shared_attributes`` columns
    are decoded from JSON; the remaining columns are returned as stored.
    """
    plain_columns = ("id", "population_id", "sample_run_id", "household_type")
    json_columns = ("adult_ids", "dependent_data", "shared_attributes")

    cur = self.conn.cursor()
    cur.execute(
        "SELECT * FROM households WHERE population_id = ? ORDER BY id",
        (population_id,),
    )

    households = []
    for row in cur.fetchall():
        record = {column: row[column] for column in plain_columns}
        for column in json_columns:
            record[column] = json.loads(row[column])
        households.append(record)
    return households
signature_json FROM network_similarity_jobs WHERE job_id = ?", @@ -649,7 +725,8 @@ def list_completed_similarity_chunks(self, job_id: str) -> list[tuple[int, int]] (job_id,), ) return [ - (int(row["chunk_start"]), int(row["chunk_end"])) for row in cursor.fetchall() + (int(row["chunk_start"]), int(row["chunk_end"])) + for row in cursor.fetchall() ] def save_similarity_chunk_rows( @@ -712,9 +789,14 @@ def load_similarity_pairs(self, job_id: str) -> dict[tuple[int, int], float]: "SELECT i, j, sim FROM network_similarity_pairs WHERE job_id = ?", (job_id,), ) - return {(int(row["i"]), int(row["j"])): float(row["sim"]) for row in cursor.fetchall()} + return { + (int(row["i"]), int(row["j"])): float(row["sim"]) + for row in cursor.fetchall() + } - def mark_similarity_job_complete(self, job_id: str, drop_pairs: bool = False) -> None: + def mark_similarity_job_complete( + self, job_id: str, drop_pairs: bool = False + ) -> None: cursor = self.conn.cursor() cursor.execute( """ @@ -725,7 +807,9 @@ def mark_similarity_job_complete(self, job_id: str, drop_pairs: bool = False) -> (_now_iso(), job_id), ) if drop_pairs: - cursor.execute("DELETE FROM network_similarity_pairs WHERE job_id = ?", (job_id,)) + cursor.execute( + "DELETE FROM network_similarity_pairs WHERE job_id = ?", (job_id,) + ) self.conn.commit() def upsert_network_generation_status( @@ -855,7 +939,16 @@ def create_simulation_run( (run_id, scenario_name, population_id, network_id, config_json, seed, status, started_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?) 
""", - (run_id, scenario_name, population_id, network_id, _dumps(config), seed, status, _now_iso()), + ( + run_id, + scenario_name, + population_id, + network_id, + _dumps(config), + seed, + status, + _now_iso(), + ), ) self.conn.commit() @@ -866,7 +959,9 @@ def update_simulation_run( stopped_reason: str | None = None, ) -> None: cursor = self.conn.cursor() - completed_at = _now_iso() if status in {"completed", "failed", "stopped"} else None + completed_at = ( + _now_iso() if status in {"completed", "failed", "stopped"} else None + ) cursor.execute( """ UPDATE simulation_runs diff --git a/tests/test_engine.py b/tests/test_engine.py index 0df6b60..c188cf5 100644 --- a/tests/test_engine.py +++ b/tests/test_engine.py @@ -905,7 +905,8 @@ async def fake_batch( return results, BatchTokenUsage() with patch( - "extropy.simulation.engine.batch_reason_agents_async", side_effect=fake_batch + "extropy.simulation.engine.batch_reason_agents_async", + side_effect=fake_batch, ): reasoned, _, _ = engine._reason_agents(0) @@ -1229,7 +1230,8 @@ async def fake_batch( engine.state_manager.record_exposure(aid, exposure) with patch( - "extropy.simulation.engine.batch_reason_agents_async", side_effect=fake_batch + "extropy.simulation.engine.batch_reason_agents_async", + side_effect=fake_batch, ): engine._reason_agents(0) @@ -1283,7 +1285,8 @@ async def fake_batch( return [(ctx.agent_id, resp) for ctx in contexts], BatchTokenUsage() with patch( - "extropy.simulation.engine.batch_reason_agents_async", side_effect=fake_batch + "extropy.simulation.engine.batch_reason_agents_async", + side_effect=fake_batch, ): engine._reason_agents(0) @@ -1492,7 +1495,8 @@ async def fake_batch( return results, usage with patch( - "extropy.simulation.engine.batch_reason_agents_async", side_effect=fake_batch + "extropy.simulation.engine.batch_reason_agents_async", + side_effect=fake_batch, ): engine._reason_agents(0) diff --git a/tests/test_household_sampling.py b/tests/test_household_sampling.py new file mode 
def _make_household_spec(size: int = 100) -> PopulationSpec:
    """Build a five-attribute spec in which ``state`` is household-scoped.

    ``age``, ``gender``, ``race_ethnicity`` and ``education_level`` are
    individual-scoped; all attributes use independent sampling.
    """

    def categorical(name, description, scope, options, weights, level, method):
        # All categorical attributes differ only in these seven values.
        return AttributeSpec(
            name=name,
            type="categorical",
            category="universal",
            description=description,
            scope=scope,
            sampling=SamplingConfig(
                strategy="independent",
                distribution=CategoricalDistribution(
                    type="categorical",
                    options=options,
                    weights=weights,
                ),
            ),
            grounding=GroundingInfo(level=level, method=method),
        )

    age_attr = AttributeSpec(
        name="age",
        type="int",
        category="universal",
        description="Age",
        scope="individual",
        sampling=SamplingConfig(
            strategy="independent",
            distribution=NormalDistribution(
                type="normal", mean=40, std=12, min=18, max=85
            ),
        ),
        grounding=GroundingInfo(level="strong", method="researched"),
    )
    gender_attr = categorical(
        "gender", "Gender", "individual",
        ["male", "female"], [0.49, 0.51],
        "strong", "researched",
    )
    # 'state' is the only attribute shared within a household.
    state_attr = categorical(
        "state", "State", "household",
        ["CA", "TX", "NY"], [0.4, 0.3, 0.3],
        "medium", "estimated",
    )
    race_attr = categorical(
        "race_ethnicity", "Race/ethnicity", "individual",
        ["white", "black", "hispanic", "asian"], [0.6, 0.13, 0.18, 0.09],
        "medium", "researched",
    )
    education_attr = categorical(
        "education_level", "Education", "individual",
        ["high_school", "bachelors", "masters", "doctorate"], [0.4, 0.3, 0.2, 0.1],
        "medium", "researched",
    )

    return PopulationSpec(
        meta=SpecMeta(description="Test household spec", size=size),
        grounding=GroundingSummary(
            overall="medium",
            sources_count=1,
            strong_count=2,
            medium_count=2,
            low_count=1,
        ),
        attributes=[age_attr, gender_attr, state_attr, race_attr, education_attr],
        sampling_order=["age", "gender", "state", "race_ethnicity", "education_level"],
    )
def _make_individual_spec(size: int = 50) -> PopulationSpec:
    """Build a two-attribute spec (age, gender) with no explicit scope.

    Neither attribute sets ``scope``, so the spec carries no household-scoped
    attributes and exercises the legacy per-individual sampling path.
    """
    strong_grounding = dict(level="strong", method="researched")

    age_attr = AttributeSpec(
        name="age",
        type="int",
        category="universal",
        description="Age",
        sampling=SamplingConfig(
            strategy="independent",
            distribution=NormalDistribution(
                type="normal", mean=40, std=12, min=18, max=85
            ),
        ),
        grounding=GroundingInfo(**strong_grounding),
    )
    gender_attr = AttributeSpec(
        name="gender",
        type="categorical",
        category="universal",
        description="Gender",
        sampling=SamplingConfig(
            strategy="independent",
            distribution=CategoricalDistribution(
                type="categorical",
                options=["male", "female"],
                weights=[0.49, 0.51],
            ),
        ),
        grounding=GroundingInfo(**strong_grounding),
    )

    summary = GroundingSummary(
        overall="medium",
        sources_count=1,
        strong_count=1,
        medium_count=1,
        low_count=0,
    )
    return PopulationSpec(
        meta=SpecMeta(description="Test individual spec", size=size),
        grounding=summary,
        attributes=[age_attr, gender_attr],
        sampling_order=["age", "gender"],
    )
class TestHouseholdModels:
    """Sanity checks on the household-related model definitions."""

    def test_household_type_enum(self):
        expected_values = {
            HouseholdType.SINGLE: "single",
            HouseholdType.COUPLE_WITH_KIDS: "couple_with_kids",
        }
        for member, value in expected_values.items():
            assert member.value == value

    def test_dependent_model(self):
        child = Dependent(
            name="Child",
            age=10,
            gender="male",
            relationship="son",
            school_status="elementary",
        )
        assert (child.age, child.school_status) == (10, "elementary")

    def test_standard_personality_attributes(self):
        for trait in ("conformity", "extraversion"):
            assert trait in STANDARD_PERSONALITY_ATTRIBUTES
        assert len(STANDARD_PERSONALITY_ATTRIBUTES) == 6
class TestHouseholdSamplingHelpers:
    """Unit tests for the stand-alone helpers in ``sampler.households``."""

    @staticmethod
    def _same_value_rate(attr_name, value, options, trials=500):
        """Fraction of ``trials`` (seed 42) in which the partner keeps ``value``."""
        rng = random.Random(42)
        kept = sum(
            1
            for _ in range(trials)
            if correlate_partner_attribute(
                attr_name, value, rng, available_options=options
            )
            == value
        )
        return kept / trials

    def test_sample_household_type_returns_valid(self):
        rng = random.Random(42)
        sampled = [sample_household_type(age, rng) for age in (22, 35, 55, 70)]
        for htype in sampled:
            assert isinstance(htype, HouseholdType)

    def test_household_needs_partner(self):
        expectations = [
            (HouseholdType.SINGLE, False),
            (HouseholdType.COUPLE, True),
            (HouseholdType.SINGLE_PARENT, False),
            (HouseholdType.COUPLE_WITH_KIDS, True),
            (HouseholdType.MULTI_GENERATIONAL, True),
        ]
        for htype, needs_partner in expectations:
            assert bool(household_needs_partner(htype)) == needs_partner

    def test_household_needs_kids(self):
        expectations = [
            (HouseholdType.SINGLE, False),
            (HouseholdType.COUPLE, False),
            (HouseholdType.SINGLE_PARENT, True),
            (HouseholdType.COUPLE_WITH_KIDS, True),
            (HouseholdType.MULTI_GENERATIONAL, True),
        ]
        for htype, needs_kids in expectations:
            assert bool(household_needs_kids(htype)) == needs_kids

    def test_correlate_age(self):
        partner_age = correlate_partner_attribute("age", 35, random.Random(42))
        assert isinstance(partner_age, int)
        assert partner_age >= 18

    def test_correlate_race_same_rate(self):
        # Expect ~90% same-race for white
        rate = self._same_value_rate(
            "race_ethnicity", "white", ["white", "black", "hispanic"]
        )
        assert 0.80 < rate < 0.97, f"Same-race rate {rate:.2f} outside expected range"

    def test_correlate_education_assortative(self):
        # Expect ~60% same education
        rate = self._same_value_rate(
            "education_level",
            "bachelors",
            ["high_school", "bachelors", "masters", "doctorate"],
        )
        assert 0.50 < rate < 0.75, f"Assortative rate {rate:.2f} outside expected range"

    def test_correlate_unknown_attribute_returns_none(self):
        outcome = correlate_partner_attribute(
            "personality", "introverted", random.Random(42)
        )
        assert outcome is None

    def test_generate_dependents_no_kids(self):
        deps = generate_dependents(HouseholdType.COUPLE, 2, 2, 40, random.Random(42))
        assert len(deps) == 0

    def test_generate_dependents_with_kids(self):
        deps = generate_dependents(
            HouseholdType.COUPLE_WITH_KIDS, 4, 2, 40, random.Random(42)
        )
        assert len(deps) == 2
        for dep in deps:
            assert isinstance(dep, Dependent)
            assert dep.age >= 0
            assert dep.relationship in ("son", "daughter")

    def test_generate_dependents_multi_generational(self):
        deps = generate_dependents(
            HouseholdType.MULTI_GENERATIONAL, 4, 2, 45, random.Random(42)
        )
        assert len(deps) == 2
        assert any(dep.relationship in ("father", "mother") for dep in deps)

    def test_estimate_household_count(self):
        for target, expected in ((100, 40), (1, 1)):
            assert estimate_household_count(target) == expected
class TestHouseholdPopulationSampling:
    """End-to-end checks of ``sample_population`` on a household-scoped spec."""

    @staticmethod
    def _sample(size):
        """Sample ``size`` agents from the household spec with seed 42."""
        return sample_population(_make_household_spec(size=size), count=size, seed=42)

    def test_household_ids_assigned(self):
        agents = self._sample(50).agents
        with_household = [a for a in agents if a.get("household_id")]
        assert len(with_household) == len(agents), "All agents should have household_id"

    def test_partner_agents_share_household(self):
        agents = self._sample(100).agents
        by_id = {a["_id"]: a for a in agents}
        for agent in agents:
            pid = agent.get("partner_id")
            if not pid:
                continue
            partner = by_id.get(pid)
            assert partner is not None, f"Partner {pid} not found"
            assert partner["household_id"] == agent["household_id"]

    def test_household_scoped_attrs_shared(self):
        agents = self._sample(100).agents
        by_id = {a["_id"]: a for a in agents}
        for agent in agents:
            pid = agent.get("partner_id")
            if not pid:
                continue
            partner = by_id.get(pid)
            assert partner is not None
            # 'state' is household-scoped, should be shared
            assert agent["state"] == partner["state"], (
                f"Household-scoped attr 'state' should match: "
                f"{agent['state']} != {partner['state']}"
            )

    def test_dependent_count_matches(self):
        required_fields = ("age", "gender", "relationship")
        for agent in self._sample(50).agents:
            deps = agent.get("dependents", [])
            assert isinstance(deps, list)
            # Each dependent should have required fields
            for dep in deps:
                for field in required_fields:
                    assert field in dep

    def test_single_households_no_partner(self):
        unpartnered_primaries = [
            a
            for a in self._sample(200).agents
            if a.get("household_role") == "adult_primary"
            and a.get("partner_id") is None
        ]
        # There should be some single-person households
        assert len(unpartnered_primaries) > 0

    def test_total_agent_count_matches_requested(self):
        # Should produce at most the requested count
        assert len(self._sample(100).agents) <= 100

    def test_meta_has_household_info(self):
        meta = self._sample(50).meta
        assert meta.get("household_mode") is True
        for key in ("household_count", "household_type_distribution"):
            assert key in meta

    def test_backward_compat_individual_spec(self):
        """Specs without household attributes should sample as before."""
        result = sample_population(_make_individual_spec(size=50), count=50, seed=42)
        assert len(result.agents) == 50
        # No household fields
        assert result.agents[0].get("household_id") is None
        assert result.meta.get("household_mode") is None

    def test_households_attached_to_result(self):
        households = getattr(self._sample(50), "_households", [])
        assert len(households) > 0
        for hh in households:
            for key in ("id", "household_type", "adult_ids"):
                assert key in hh
            assert len(hh["adult_ids"]) >= 1
class TestCorrelatedDemographics:
    """Checks that partner demographics are correlated, not independent."""

    def test_partner_age_correlation(self):
        """Partners should have correlated ages (within a few years)."""
        result = sample_population(_make_household_spec(size=200), count=200, seed=42)
        by_id = {a["_id"]: a for a in result.agents}

        gaps = [
            abs(agent["age"] - by_id[agent["partner_id"]]["age"])
            for agent in result.agents
            if agent.get("partner_id") and agent.get("partner_id") in by_id
        ]

        if gaps:
            avg_diff = sum(gaps) / len(gaps)
            # Average age gap should be small (< 10 years)
            assert avg_diff < 10, f"Average age gap {avg_diff:.1f} too large"
def _make_household_agents(n_households: int = 10) -> tuple[list[dict], list[str]]:
    """Build deterministic agents grouped into ``n_households`` households.

    Every household has a primary adult; even-numbered households also get a
    secondary adult whose ``partner_id`` points back at the primary.  Every
    third household (0, 3, 6, ...) has one elementary-school son, and partners
    share the very same ``dependents`` list object.

    Returns:
        ``(agents, agent_ids)`` in creation order.
    """
    roster: list[dict] = []
    roster_ids: list[str] = []
    next_index = 0

    for hh in range(n_households):
        household_id = f"household_{hh:04d}"
        has_partner = hh % 2 == 0
        every_third = hh % 3 == 0

        if every_third:
            dependents = [
                {
                    "age": 10,
                    "gender": "male",
                    "relationship": "son",
                    "school_status": "elementary",
                }
            ]
        else:
            dependents = []

        # Adult 1
        primary_id = f"agent_{next_index:04d}"
        primary = {
            "_id": primary_id,
            "household_id": household_id,
            "household_role": "adult_primary",
            "partner_id": f"agent_{next_index + 1:04d}" if has_partner else None,
            "age": 35 + hh,
            "gender": "female" if every_third else "male",
            "state": "CA" if hh < 5 else "TX",
            "urban_rural": "urban" if hh < 7 else "rural",
            "occupation_sector": "tech" if hh < 6 else "healthcare",
            "religious_affiliation": "christian" if hh < 4 else "none",
            "dependents": dependents,
        }
        roster.append(primary)
        roster_ids.append(primary_id)
        next_index += 1

        if not has_partner:
            continue

        # Adult 2 (partner) for even-numbered households
        secondary_id = f"agent_{next_index:04d}"
        secondary = {
            "_id": secondary_id,
            "household_id": household_id,
            "household_role": "adult_secondary",
            "partner_id": primary_id,
            "age": 33 + hh,
            "gender": "male" if every_third else "female",
            "state": primary["state"],
            "urban_rural": primary["urban_rural"],
            "occupation_sector": "education" if hh < 3 else "tech",
            "religious_affiliation": primary["religious_affiliation"],
            "dependents": primary["dependents"],
        }
        roster.append(secondary)
        roster_ids.append(secondary_id)
        next_index += 1

    return roster, roster_ids
class TestStructuralEdgeGeneration:
    """Checks on the edges produced by ``_generate_structural_edges``."""

    @staticmethod
    def _build(n_households):
        """Return ``(agents, edges)`` for the fixture population, seed 42."""
        agents, agent_ids = _make_household_agents(n_households)
        edges = _generate_structural_edges(agents, agent_ids, random.Random(42))
        return agents, edges

    @staticmethod
    def _of_type(edges, edge_type):
        """Filter ``edges`` down to those with the given ``edge_type``."""
        return [edge for edge in edges if edge.edge_type == edge_type]

    def test_partner_edges_created(self):
        agents, edges = self._build(10)
        # Count agents that have a partner_id set
        partnered = sum(1 for agent in agents if agent.get("partner_id"))
        assert len(self._of_type(edges, "partner")) == partnered // 2

    def test_partner_edge_weight(self):
        _, edges = self._build(4)
        for edge in self._of_type(edges, "partner"):
            assert edge.weight == 1.0
            assert edge.structural is True
            assert edge.context == "household"

    def test_household_edges_connect_members(self):
        _, edges = self._build(6)
        # Household edges connect adults in the same household (weight 0.9)
        for edge in self._of_type(edges, "household"):
            assert edge.weight == 0.9
            assert edge.structural is True

    def test_coworker_edges_respect_sector_state(self):
        agents, edges = self._build(10)
        by_id = {agent["_id"]: agent for agent in agents}
        for edge in self._of_type(edges, "coworker"):
            src, dst = by_id[edge.source], by_id[edge.target]
            assert src["occupation_sector"] == dst["occupation_sector"]
            assert src["state"] == dst["state"]
            assert edge.weight == 0.6
            assert edge.context == "workplace"

    def test_neighbor_edges_respect_age_constraint(self):
        agents, edges = self._build(10)
        by_id = {agent["_id"]: agent for agent in agents}
        for edge in self._of_type(edges, "neighbor"):
            src, dst = by_id[edge.source], by_id[edge.target]
            assert abs(src["age"] - dst["age"]) <= 15
            assert src["state"] == dst["state"]
            assert src["urban_rural"] == dst["urban_rural"]

    def test_all_edges_are_structural(self):
        _, edges = self._build(10)
        for edge in edges:
            assert edge.structural is True
            assert edge.context is not None

    def test_school_parent_edges(self):
        agents, edges = self._build(10)
        school_levels = ("elementary", "middle_school", "high_school")
        by_id = {agent["_id"]: agent for agent in agents}
        # Agents with school-age dependents in same state+urban should connect
        for edge in self._of_type(edges, "school_parent"):
            src, dst = by_id[edge.source], by_id[edge.target]
            assert src["state"] == dst["state"]
            assert src["urban_rural"] == dst["urban_rural"]
            # Both should have school-age dependents
            for endpoint in (src, dst):
                has_school = any(
                    dep.get("school_status") in school_levels
                    for dep in endpoint.get("dependents", [])
                )
                assert has_school

    def test_no_self_edges(self):
        _, edges = self._build(10)
        for edge in edges:
            assert edge.source != edge.target

    def test_no_duplicate_edges(self):
        _, edges = self._build(10)
        seen = set()
        for edge in edges:
            pair = tuple(sorted((edge.source, edge.target)))
            assert pair not in seen, f"Duplicate edge: {pair}"
            seen.add(pair)
class TestEdgeModelEnhancements:
    """Checks on the ``structural`` and ``context`` fields of ``Edge``."""

    def test_edge_structural_default(self):
        plain = Edge(source="a", target="b", weight=0.5, edge_type="peer")
        assert plain.structural is False
        assert plain.context is None

    def test_edge_structural_to_dict(self):
        serialized = Edge(
            source="a",
            target="b",
            weight=0.5,
            edge_type="partner",
            structural=True,
            context="household",
        ).to_dict()
        assert serialized["structural"] is True
        assert serialized["context"] == "household"

    def test_edge_non_structural_to_dict_omits_fields(self):
        serialized = Edge(
            source="a", target="b", weight=0.5, edge_type="peer"
        ).to_dict()
        for omitted in ("structural", "context"):
            assert omitted not in serialized


class TestNetworkConfigEnhancements:
    """Checks on the degree-distribution and identity-clustering settings."""

    def test_degree_distribution_target_default(self):
        defaults = NetworkConfig()
        assert defaults.degree_distribution_target is None
        assert defaults.power_law_exponent == 2.5

    def test_identity_clustering_defaults(self):
        defaults = NetworkConfig()
        assert defaults.identity_clustering_attributes == []
        assert defaults.identity_clustering_boost == 1.5

    def test_identity_clustering_config(self):
        clustering_attrs = ["political_orientation", "religious_affiliation"]
        custom = NetworkConfig(
            identity_clustering_attributes=clustering_attrs,
            identity_clustering_boost=2.0,
        )
        assert len(custom.identity_clustering_attributes) == 2
        assert custom.identity_clustering_boost == 2.0