CLAUDE.md (39 additions & 2 deletions)
@@ -35,14 +35,32 @@ Both classes are composed of internal specialist objects created at `__init__` t
 | Class | File | Responsibility |
 |---|---|---|
-|`GPdata`|[gp_data.py](fvgp/gp_data.py)| Data validation, shape tracking, Euclidean vs. non-Euclidean |
-|`GPprior`|[gp_prior.py](fvgp/gp_prior.py)| Kernel and mean function; default is anisotropic Matérn with ARD |
+|`GPdata`|[gp_data.py](fvgp/gp_data.py)| Data validation, shape tracking, Euclidean vs. non-Euclidean. Sole source of truth for `x_data`, `y_data`, `noise_variances`, plus the pre-append snapshot (`x_old`, `y_old`, `noise_variances_old`) and last-appended chunk (`x_new`, `y_new`, `noise_variances_new`)|
+|`GPprior`|[gp_prior.py](fvgp/gp_prior.py)| Kernel and mean function (default: anisotropic Matérn with ARD). In gp2Scale mode also owns `x_data_scatter_future` (the persistent dask scatter of `x_data`)|
 |`GPlikelihood`|[gp_likelihood.py](fvgp/gp_likelihood.py)| Noise model (variances or callable) |
 |`GPkv`|[gp_kv.py](fvgp/gp_kv.py)| Owns K+V matrix state and all factorizations; dispatches solves/logdets across linalg modes |
 |`GPMarginalLikelihood`|[gp_marginal_likelihood.py](fvgp/gp_marginal_likelihood.py)| Log marginal likelihood and its gradient; delegates factorization to `GPkv`|
Sources of truth: `GPtraining.hyperparameters` and `GPdata.x_data` / `y_data` / `noise_variances`. Everywhere else reads these via `@property`. Cached state that must be invalidated on a change:
|`GP.train(...)` (sync) / `GP.update_hyperparameters(opt_obj)` (async) | both end with `set_hyperparameters(...)`|
+
+`GPposterior` and `GPMarginalLikelihood` hold **no cached state**: every read goes through properties, so they're automatically consistent.
+
+Gotchas:
+- **`GP.set_args(new_args)` does NOT invalidate `K`, `m`, `V`, or factorizations.** If `args` flows into a user kernel/mean/noise callable, new args take effect only on the next `set_hyperparameters`, `update_gp_data(append=False)`, fresh `train`, or posterior call with explicit `hyperparameters=`. To force a flush: `set_hyperparameters(self.hyperparameters)`.
+- **`update_gp_data(append=False, rank_n_update=True)`** is invalid (the previous factorization is for data that no longer exists); `GP.update_gp_data` emits a `UserWarning` and forces `rank_n_update=False`.
+- **`kv.solve(b, x0=...)`** zero-pads `x0` along axis 0 when shapes don't match, so a pre-append `KVinvY` can warm-start the post-append solve in iterative modes (sparseCG/MINRES/preconditioned variants). See [gp_kv.py:333-342](fvgp/gp_kv.py#L333-L342).
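The zero-padding warm start described in the last gotcha can be sketched in isolation. This is a minimal illustration, not fvgp's code: `pad_x0` and `conjugate_gradient` are hypothetical helpers written for this example, and the small dense matrix `KV` stands in for K+V after an append.

```python
import numpy as np


def pad_x0(x0, n):
    """Zero-pad x0 along axis 0 up to n rows (hypothetical helper, not fvgp's code)."""
    missing = n - x0.shape[0]
    if missing <= 0:
        return x0
    pad_width = [(0, missing)] + [(0, 0)] * (x0.ndim - 1)
    return np.pad(x0, pad_width)


def conjugate_gradient(A, b, x0, tol=1e-10, max_iter=1000):
    """Plain CG for a symmetric positive-definite A; returns (x, iterations)."""
    x = x0.astype(float).copy()
    r = b - A @ x
    p = r.copy()
    rs = r @ r
    for it in range(max_iter):
        if np.sqrt(rs) <= tol * np.linalg.norm(b):
            return x, it
        Ap = A @ p
        alpha = rs / (p @ Ap)
        x += alpha * p
        r -= alpha * Ap
        rs_new = r @ r
        p = r + (rs_new / rs) * p
        rs = rs_new
    return x, max_iter


rng = np.random.default_rng(0)
M = rng.standard_normal((8, 8))
KV = M @ M.T + 8.0 * np.eye(8)              # stand-in for K+V after the append
y = rng.standard_normal(8)

x_old = np.linalg.solve(KV[:5, :5], y[:5])  # pre-append solution (5 points)
x0 = pad_x0(x_old, KV.shape[0])             # padded to the post-append size
x, n_iter = conjugate_gradient(KV, y, x0)
```

The padded entries start at zero, so the solver only has to correct the old block and fill in the rows for the new points.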
@@ -55,6 +73,25 @@ Both classes are composed of internal specialist objects created at `__init__` t
 When `gp2Scale=True`, `GP` switches to a Wendland (compactly supported) kernel producing sparse covariance matrices and uses Dask for distributed computation. This path requires a Dask client to be passed in and uses sparse linear solvers instead of dense Cholesky.

+**Scatter ownership and lifecycle:**
+
+- `GPprior.x_data_scatter_future` is the single persistent dask scatter of the current `x_data`. Scattered once at `GPprior.__init__` (see [gp_prior.py:93-96](fvgp/gp_prior.py#L93-L96)).
+- `GPdata` does NOT scatter; it's pure-Python data only.
+- `_compute_prior_covariance_gp2Scale` reads `self.x_data_scatter_future` directly; **no scatter per call**, so training stays dask-quiet.
+- On data changes, `augment_state_data` / `update_state_data` refresh the scatter by **overwriting** `self.x_data_scatter_future` (no explicit `release()`). The old future loses its only Python ref and is cleaned up via `__del__`. Calling `release()` explicitly schedules a `_dec_ref` that races against subsequent scatter `replicate` operations in the scheduler. Don't do it.
+- `_update_prior_covariance_gp2Scale` (the augment path) uses `self.x_data_scatter_future` for the `x_old` side (no content-hash collision since it shares the existing key) and scatters only `x_new` locally, releasing that local future at the end.
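The overwrite-instead-of-release lifecycle from the list above can be modelled without dask. The sketch below is only an analogy: `FakeFuture` and `PriorSketch` are hypothetical stand-ins, and the `released` list plays the role of the scheduler-side cleanup that a real future triggers from `__del__`.

```python
import gc

released = []  # keys whose stand-in future has been finalized


class FakeFuture:
    """Stand-in for a dask Future: records its key when garbage-collected."""

    def __init__(self, key):
        self.key = key

    def __del__(self):
        released.append(self.key)


class PriorSketch:
    """Hypothetical owner of a single persistent scatter handle."""

    def __init__(self, x_data):
        self.x_data_scatter_future = FakeFuture(("x_data", len(x_data)))

    def update_state_data(self, x_data):
        # Overwrite only; no explicit release() on the old handle.
        self.x_data_scatter_future = FakeFuture(("x_data", len(x_data)))


prior = PriorSketch([0.0, 1.0])
prior.update_state_data([0.0, 1.0, 2.0])
gc.collect()  # not needed under CPython refcounting; harmless elsewhere
```

After the overwrite, the old handle has no remaining references, so its finalizer runs and only the new handle stays alive on the owner.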
+
+**Cross-instance race guard:** [gp.py:14-21](fvgp/gp.py#L14-L21) defines `_GP_INSTANCES_PER_CLIENT`, a `WeakValueDictionary` keyed by `dask_client.id`. `GP.__init__` ([gp.py:285-303](fvgp/gp.py#L285-L303)) raises with a descriptive remediation message if you try to construct a second gp2Scale `GP` on a client that already has a live one; that pattern reliably triggers `FutureCancelledError`/`KeyError` from the scheduler. To reuse a client for a sequence of GPs:
+
+```python
+import gc
+del previous_gp
+gc.collect()
+client.run(lambda: None)  # flush pending releases
+```
+
+The `test_gp2Scale` test uses exactly this pattern between linalg-mode iterations.
+
 ### Customization API

 Kernels, mean functions, and noise models are all plain Python callables with standardized signatures. Users pass them as arguments to `GP`/`fvGP` constructors. The full hyperparameter vector is shared across kernel, mean, and noise callables, but each callable must only read its reserved index range. Kernel gradients can be user-supplied or computed via finite differences.
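The shared-vector convention can be illustrated with three small callables. This is a sketch under assumptions: the signatures `kernel(x1, x2, hps)`, `mean(x, hps)`, and `noise(x, hps)`, the index layout, and the helper `finite_difference_gradient` are chosen for the example and should be checked against the library's documented signatures before use.

```python
import numpy as np

# Assumed layout of the shared vector for 2-D inputs (illustrative only):
#   hps[0]   signal variance   (kernel)
#   hps[1:3] length scales     (kernel)
#   hps[3]   constant mean     (mean)
#   hps[4]   noise variance    (noise)


def kernel(x1, x2, hps):
    """Anisotropic squared-exponential kernel reading only hps[0:3]."""
    diff = (x1[:, None, :] - x2[None, :, :]) / hps[1:3]
    return hps[0] * np.exp(-0.5 * np.sum(diff**2, axis=-1))


def mean(x, hps):
    """Constant mean reading only hps[3]."""
    return np.full(len(x), hps[3])


def noise(x, hps):
    """Homoscedastic noise variances reading only hps[4]."""
    return np.full(len(x), hps[4])


def finite_difference_gradient(x1, x2, hps, eps=1e-6):
    """Central-difference dK/dhps, one slice per hyperparameter."""
    grad = np.empty((len(hps), len(x1), len(x2)))
    for i in range(len(hps)):
        step = np.zeros_like(hps)
        step[i] = eps
        grad[i] = (kernel(x1, x2, hps + step) - kernel(x1, x2, hps - step)) / (2 * eps)
    return grad


x = np.array([[0.0, 0.0], [1.0, 0.5], [2.0, 1.0]])
hps = np.array([2.0, 1.0, 0.5, 0.3, 0.01])
K = kernel(x, x, hps)
dK = finite_difference_gradient(x, x, hps)
```

The gradient slices for indices 3 and 4 come out as zero because the kernel never reads them, which is a quick way to check that a callable stays within its reserved range.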
# Appends and rank_n_updates for gp2Scale are not yet fully tested. Have to check the compute graph and test (what does rank_n_update even mean for the different modes?).
@@ -147,12 +154,12 @@ class GP:
         If no kernel is provided, the ``compute_device`` option should be revisited.
         The default kernel will use the specified device to compute covariances.
         The default is False.
+    gp2Scale_batch_size : int, optional
+        Matrix batch size for distributed computing in gp2Scale. The default is 10000.
     dask_client : dask.distributed.Client, optional
         A dask client for gp2Scale, asynchronous training, and certain linear algebra operations.
         On HPC architecture, this client is provided by the job script. Please have a look at the examples.
         A local client is used as the default.
-    gp2Scale_batch_size : int, optional
-        Matrix batch size for distributed computing in gp2Scale. The default is 10000.
     linalg_mode : str, optional
         Controls the linear-algebra backend used to solve (K+V)x=b and compute log|K+V|.
         The default is ``None``, which selects ``"Chol"`` for standard GPs and automatically
The default is ``None``, which selects ``"Chol"`` for standard GPs and automatically