|
40 | 40 | ) |
41 | 41 |
|
42 | 42 |
|
| 43 | +def merge_env_stat( |
| 44 | + base_obj: Union["Descriptor", "DescriptorBlock"], |
| 45 | + link_obj: Union["Descriptor", "DescriptorBlock"], |
| 46 | + model_prob: float = 1.0, |
| 47 | +) -> None: |
| 48 | + """Merge descriptor env mat stats from link_obj into base_obj. |
| 49 | +
|
| 50 | + Uses probability-weighted merging: merged = base_stats + link_stats * model_prob, |
| 51 | + where model_prob = link_prob / base_prob. |
| 52 | + Mutates base_obj.stats for chaining (3+ models). |
| 53 | +
|
| 54 | + Parameters |
| 55 | + ---------- |
| 56 | + base_obj : Descriptor or DescriptorBlock |
| 57 | + The base descriptor whose stats will be updated. |
| 58 | + link_obj : Descriptor or DescriptorBlock |
| 59 | + The linked descriptor whose stats will be merged in. |
| 60 | + model_prob : float |
| 61 | + The probability weight ratio (link_prob / base_prob). |
| 62 | + """ |
| 63 | + if ( |
| 64 | + getattr(base_obj, "stats", None) is None |
| 65 | + or getattr(link_obj, "stats", None) is None |
| 66 | + ): |
| 67 | + return |
| 68 | + if getattr(base_obj, "set_stddev_constant", False) and getattr( |
| 69 | + base_obj, "set_davg_zero", False |
| 70 | + ): |
| 71 | + return |
| 72 | + |
| 73 | + # Weighted merge of StatItem objects |
| 74 | + base_stats = base_obj.stats |
| 75 | + link_stats = link_obj.stats |
| 76 | + merged_stats = {} |
| 77 | + for kk in base_stats: |
| 78 | + merged_stats[kk] = base_stats[kk] + link_stats[kk] * model_prob |
| 79 | + |
| 80 | + # Compute mean/stddev from merged stats |
| 81 | + base_env = EnvMatStatSe(base_obj) |
| 82 | + base_env.stats = merged_stats |
| 83 | + mean, stddev = base_env() |
| 84 | + |
| 85 | + # Update base_obj stats for chaining |
| 86 | + base_obj.stats = merged_stats |
| 87 | + |
| 88 | + # Update buffers in-place: davg/dstd (simple) or mean/stddev (blocks) |
| 89 | + # mean/stddev are numpy arrays; convert to match the buffer's backend |
| 90 | + if hasattr(base_obj, "davg"): |
| 91 | + xp = array_api_compat.array_namespace(base_obj.dstd) |
| 92 | + device = array_api_compat.device(base_obj.dstd) |
| 93 | + if not getattr(base_obj, "set_davg_zero", False): |
| 94 | + base_obj.davg[...] = xp.asarray( |
| 95 | + mean, dtype=base_obj.davg.dtype, device=device |
| 96 | + ) |
| 97 | + base_obj.dstd[...] = xp.asarray( |
| 98 | + stddev, dtype=base_obj.dstd.dtype, device=device |
| 99 | + ) |
| 100 | + elif hasattr(base_obj, "mean"): |
| 101 | + xp = array_api_compat.array_namespace(base_obj.stddev) |
| 102 | + device = array_api_compat.device(base_obj.stddev) |
| 103 | + if not getattr(base_obj, "set_davg_zero", False): |
| 104 | + base_obj.mean[...] = xp.asarray( |
| 105 | + mean, dtype=base_obj.mean.dtype, device=device |
| 106 | + ) |
| 107 | + base_obj.stddev[...] = xp.asarray( |
| 108 | + stddev, dtype=base_obj.stddev.dtype, device=device |
| 109 | + ) |
| 110 | + |
| 111 | + |
43 | 112 | class EnvMatStat(BaseEnvMatStat): |
44 | 113 | def compute_stat(self, env_mat: dict[str, Array]) -> dict[str, StatItem]: |
45 | 114 | """Compute the statistics of the environment matrix for a single system. |
|
0 commit comments