
Commit b7eaadd

committed
Data changed to hold on to new and old data as well, scatter in prior, fail capture for many GP instances running on one client.
1 parent c6ecf8b commit b7eaadd

11 files changed

Lines changed: 204 additions & 78 deletions


CLAUDE.md

Lines changed: 39 additions & 2 deletions
@@ -35,14 +35,32 @@ Both classes are composed of internal specialist objects created at `__init__` t
3535

3636
| Class | File | Responsibility |
3737
|---|---|---|
38-
| `GPdata` | [gp_data.py](fvgp/gp_data.py) | Data validation, shape tracking, Euclidean vs. non-Euclidean |
39-
| `GPprior` | [gp_prior.py](fvgp/gp_prior.py) | Kernel and mean function; default is anisotropic Matérn with ARD |
38+
| `GPdata` | [gp_data.py](fvgp/gp_data.py) | Data validation, shape tracking, Euclidean vs. non-Euclidean. Sole source of truth for `x_data`, `y_data`, `noise_variances`, plus the pre-append snapshot (`x_old`, `y_old`, `noise_variances_old`) and last-appended chunk (`x_new`, `y_new`, `noise_variances_new`) |
39+
| `GPprior` | [gp_prior.py](fvgp/gp_prior.py) | Kernel and mean function (default: anisotropic Matérn with ARD). In gp2Scale mode also owns `x_data_scatter_future` (the persistent dask scatter of `x_data`) |
4040
| `GPlikelihood` | [gp_likelihood.py](fvgp/gp_likelihood.py) | Noise model (variances or callable) |
4141
| `GPkv` | [gp_kv.py](fvgp/gp_kv.py) | Owns K+V matrix state and all factorizations; dispatches solves/logdets across linalg modes |
4242
| `GPMarginalLikelihood` | [gp_marginal_likelihood.py](fvgp/gp_marginal_likelihood.py) | Log marginal likelihood and its gradient; delegates factorization to `GPkv` |
4343
| `GPposterior` | [gp_posterior.py](fvgp/gp_posterior.py) | Posterior mean/covariance; information-theoretic quantities |
4444
| `GPtraining` | [gp_training.py](fvgp/gp_training.py) | Hyperparameter optimization (scipy, hgdl async, MCMC, Adam) |
4545

46+
### State propagation
47+
48+
Sources of truth: `GPtraining.hyperparameters` and `GPdata.x_data` / `y_data` / `noise_variances`. Everywhere else reads these via `@property`. Cached state that must be invalidated on a change:
49+
50+
| Mutator | What's refreshed |
51+
|---|---|
52+
| `GP.set_hyperparameters(hps)` | `trainer.hyperparameters` → `prior.update_state_hyperparameters()` (recomputes `m`, `K`) → `likelihood.update_state()` (`V`) → `kv.update_state_hyperparameters()` (factorization + `KVinvY`) |
53+
| `GP.update_gp_data(..., append=True)` | `data.update()` snapshots `x_old`/`y_old`/etc. → `prior.augment_state_data()` (rank-n update of `m`, `K`) → `likelihood.update_state()` → `kv.update_state_data(rank_n_update)` |
54+
| `GP.update_gp_data(..., append=False)` | `data.update()` clears `_old`/`_new` slots → `prior.update_state_data()` (full recompute) → `likelihood.update_state()` → `kv.update_state_data(rank_n_update)` |
55+
| `GP.train(...)` (sync) / `GP.update_hyperparameters(opt_obj)` (async) | both end with `set_hyperparameters(...)` |
56+
57+
`GPposterior` and `GPMarginalLikelihood` hold **no cached state** — every read goes through properties, so they're automatically consistent.
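The append-versus-replace bookkeeping in the table above can be illustrated with a minimal sketch. `DataStore` is a hypothetical stand-in written for this note, not fvgp's `GPdata`; it only mirrors the snapshot behavior described here (pre-append snapshot in `x_old`/`y_old`, last-appended chunk in `x_new`/`y_new`, both cleared on a full replace) and assumes Euclidean inputs stored as 2-D arrays.

```python
import numpy as np


class DataStore:
    """Hypothetical container illustrating the old/new snapshot slots."""

    def __init__(self, x_data, y_data):
        self.x_data, self.y_data = x_data, y_data
        self.x_old = self.y_old = None
        self.x_new = self.y_new = None

    def update(self, x_new, y_new, append=True):
        if append:
            # Snapshot the pre-append data so a downstream rank-n update
            # can tell which rows are old and which are new.
            self.x_old, self.y_old = self.x_data, self.y_data
            self.x_new, self.y_new = x_new, y_new
            self.x_data = np.vstack([self.x_data, x_new])
            self.y_data = np.vstack([self.y_data, y_new])
        else:
            # Full replacement: the old/new split is meaningless, so clear it.
            self.x_data, self.y_data = x_new, y_new
            self.x_old = self.y_old = None
            self.x_new = self.y_new = None
```

After an append, `x_old` and `x_new` stacked together reproduce `x_data`, which is the property a rank-n update of `K` relies on.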
58+
59+
Gotchas:
60+
- **`GP.set_args(new_args)` does NOT invalidate `K`, `m`, `V`, or factorizations.** If `args` flows into a user kernel/mean/noise callable, new args take effect only on the next `set_hyperparameters`, `update_gp_data(append=False)`, fresh `train`, or posterior call with explicit `hyperparameters=`. To force a flush: `set_hyperparameters(self.hyperparameters)`.
61+
- **`update_gp_data(append=False, rank_n_update=True)`** is invalid (the previous factorization is for data that no longer exists); `GP.update_gp_data` emits a `UserWarning` and forces `rank_n_update=False`.
62+
- **`kv.solve(b, x0=...)`** zero-pads `x0` along axis 0 when shapes don't match, so a pre-append `KVinvY` can warm-start the post-append solve in iterative modes (sparseCG/MINRES/preconditioned variants). See [gp_kv.py:333-342](fvgp/gp_kv.py#L333-L342).
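The zero-padding idea from the last gotcha can be sketched in a few lines of NumPy. `pad_warm_start` is a hypothetical helper named for this illustration; it is not the code in `gp_kv.py`, and raising when `x0` is longer than the target is a choice made for this sketch only.

```python
import numpy as np


def pad_warm_start(x0, n_rows):
    """Zero-pad x0 along axis 0 so it has n_rows rows (illustrative only)."""
    x0 = np.asarray(x0, dtype=float)
    missing = n_rows - x0.shape[0]
    if missing < 0:
        # Assumption of this sketch: a longer guess is treated as an error.
        raise ValueError("x0 has more rows than the system being solved")
    if missing == 0:
        return x0
    # Pad only the first axis; any trailing axes (e.g. multiple right-hand
    # sides) keep their size.
    pad_width = [(0, missing)] + [(0, 0)] * (x0.ndim - 1)
    return np.pad(x0, pad_width, mode="constant", constant_values=0.0)
```

A solution computed for the pre-append system then serves as an initial guess for the larger post-append system, with zeros in the rows that belong to the new data points.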
63+
4664
### Key supporting modules
4765

4866
- **[gp_lin_alg.py](fvgp/gp_lin_alg.py)** — CPU/GPU linear algebra primitives; Cholesky, LU, sparse solvers; defines `NonPositiveDefiniteError`
@@ -55,6 +73,25 @@ Both classes are composed of internal specialist objects created at `__init__` t
5573

5674
When `gp2Scale=True`, `GP` switches to a Wendland (compactly supported) kernel producing sparse covariance matrices and uses Dask for distributed computation. This path requires a Dask client to be passed in and uses sparse linear solvers instead of dense Cholesky.
5775

76+
**Scatter ownership and lifecycle:**
77+
78+
- `GPprior.x_data_scatter_future` is the single persistent dask scatter of the current `x_data`. Scattered once at `GPprior.__init__` (see [gp_prior.py:93-96](fvgp/gp_prior.py#L93-L96)).
79+
- `GPdata` does NOT scatter — it's pure-Python data only.
80+
- `_compute_prior_covariance_gp2Scale` reads `self.x_data_scatter_future` directly; **no scatter per call**, so training stays dask-quiet.
81+
- On data changes, `augment_state_data` / `update_state_data` refresh the scatter by **overwriting** `self.x_data_scatter_future` (no explicit `release()`). The old future loses its only Python ref and is cleaned up via `__del__`. Calling `release()` explicitly schedules a `_dec_ref` that races against subsequent scatter `replicate` operations in the scheduler — don't do it.
82+
- `_update_prior_covariance_gp2Scale` (the augment path) uses `self.x_data_scatter_future` for the `x_old` side (no content-hash collision since it shares the existing key) and scatters only `x_new` locally, releasing that local future at the end.
83+
84+
**Cross-instance race guard:** [gp.py:14-21](fvgp/gp.py#L14-L21) defines `_GP_INSTANCES_PER_CLIENT`, a `WeakValueDictionary` keyed by `dask_client.id`. `GP.__init__` ([gp.py:285-303](fvgp/gp.py#L285-L303)) raises with a descriptive remediation message if you try to construct a second gp2Scale `GP` on a client that already has a live one — that pattern reliably triggers `FutureCancelledError`/`KeyError` from the scheduler. To reuse a client for a sequence of GPs:
85+
86+
```python
87+
import gc
88+
del previous_gp
89+
gc.collect()
90+
client.run(lambda: None) # flush pending releases
91+
```
92+
93+
The `test_gp2Scale` test uses exactly this pattern between linalg-mode iterations.
94+
5895
### Customization API
5996

6097
Kernels, mean functions, and noise models are all plain Python callables with standardized signatures. Users pass them as arguments to `GP`/`fvGP` constructors. The full hyperparameter vector is shared across kernel, mean, and noise callables, but each callable must only read its reserved index range. Kernel gradients can be user-supplied or computed via finite differences.

examples/gp2ScaleTest.ipynb

Lines changed: 16 additions & 2 deletions
@@ -98,7 +98,7 @@
9898
},
9999
{
100100
"cell_type": "code",
101-
"execution_count": null,
101+
"execution_count": 8,
102102
"id": "fe1f017a",
103103
"metadata": {},
104104
"outputs": [
@@ -111,8 +111,22 @@
111111
"Finished 20 out of 100 iterations. f(x)= -12722.285286212222\n",
112112
"Finished 30 out of 100 iterations. f(x)= -12485.189964773206\n",
113113
"Finished 40 out of 100 iterations. f(x)= -12473.181834409832\n",
114-
"Finished 50 out of 100 iterations. f(x)= -12473.181834409832\n"
114+
"Finished 50 out of 100 iterations. f(x)= -12473.181834409832\n",
115+
"Finished 60 out of 100 iterations. f(x)= -12466.485887381574\n",
116+
"Finished 70 out of 100 iterations. f(x)= -12460.7203909633\n",
117+
"Finished 80 out of 100 iterations. f(x)= -12460.7203909633\n",
118+
"Finished 90 out of 100 iterations. f(x)= -12460.7203909633\n"
115119
]
120+
},
121+
{
122+
"data": {
123+
"text/plain": [
124+
"array([0.14048159, 0.03980175])"
125+
]
126+
},
127+
"execution_count": 8,
128+
"metadata": {},
129+
"output_type": "execute_result"
116130
}
117131
],
118132
"source": [

fvgp/fvgp.py

Lines changed: 2 additions & 2 deletions
@@ -154,12 +154,12 @@ class fvGP(GP):
154154
If no kernel is provided, the ``compute_device`` option should be revisited.
155155
The default kernel will use the specified device to compute covariances.
156156
The default is False.
157+
gp2Scale_batch_size : int, optional
158+
Matrix batch size for distributed computing in gp2Scale. The default is 10000.
157159
dask_client : dask.distributed.Client, optional
158160
A dask client for gp2Scale, asynchronous training, and certain linear algebra operations.
159161
On HPC architecture, this client is provided by the job script. Please have a look at the examples.
160162
A local client is used as the default.
161-
gp2Scale_batch_size : int, optional
162-
Matrix batch size for distributed computing in gp2Scale. The default is 10000.
163163
linalg_mode : str, optional
164164
Controls the linear-algebra backend used to solve (K+V)x=b and compute log|K+V|.
165165
The default is ``None``, which selects ``"Chol"`` for standard GPs and automatically

fvgp/gp.py

Lines changed: 37 additions & 7 deletions
@@ -1,4 +1,5 @@
11
import warnings
2+
import weakref
23
import numpy as np
34
from loguru import logger
45
from distributed import Client
@@ -13,6 +14,12 @@
1314
import importlib
1415
warnings.simplefilter("once", UserWarning)
1516

17+
# Tracks live GP instances per dask client (gp2Scale mode only). Used to detect
18+
# the case where a user creates a second GP on a client that still has a live GP,
19+
# which triggers race conditions between the new init scatter and the pending
20+
# `_dec_ref` callbacks from the previous GP's scatter activity.
21+
_GP_INSTANCES_PER_CLIENT = weakref.WeakValueDictionary()
22+
1623
# TODO: also search below "TODO"
1724
# Appends and rank_n_updates for gp2Scale are not yet fully tested. Have to check the compute graph and test (what does rank_n_update even mean for the different modes? ).
1825

@@ -147,12 +154,12 @@ class GP:
147154
If no kernel is provided, the ``compute_device`` option should be revisited.
148155
The default kernel will use the specified device to compute covariances.
149156
The default is False.
157+
gp2Scale_batch_size : int, optional
158+
Matrix batch size for distributed computing in gp2Scale. The default is 10000.
150159
dask_client : dask.distributed.Client, optional
151160
A dask client for gp2Scale, asynchronous training, and certain linear algebra operations.
152161
On HPC architecture, this client is provided by the job script. Please have a look at the examples.
153162
A local client is used as the default.
154-
gp2Scale_batch_size : int, optional
155-
Matrix batch size for distributed computing in gp2Scale. The default is 10000.
156163
linalg_mode : str, optional
157164
Controls the linear-algebra backend used to solve (K+V)x=b and compute log|K+V|.
158165
The default is ``None``, which selects ``"Chol"`` for standard GPs and automatically
@@ -274,6 +281,27 @@ def __init__(
274281
# Check gp2Scale
275282
dask_client = self.initialize_gp2Scale_dask_client(gp2Scale, dask_client)
276283

284+
# Race-condition guard: in gp2Scale mode, only one GP can be alive per dask
285+
# client. Sharing a client between live GPs causes the new GP's init scatter
286+
# to race against the previous GP's pending `_dec_ref` callbacks, surfacing as
287+
# `FutureCancelledError` or `KeyError` from the scheduler.
288+
if gp2Scale and dask_client is not None:
289+
existing = _GP_INSTANCES_PER_CLIENT.get(dask_client.id)
290+
if existing is not None and existing is not self:
291+
raise Exception(
292+
f"Another GP instance is already active on this dask client "
293+
f"(client.id={dask_client.id!r}). Sharing a dask client between "
294+
f"multiple live GPs in gp2Scale mode triggers race conditions "
295+
f"in the scheduler's scatter reference counting.\n"
296+
f"To reuse the same client for a sequence of GPs, destroy the "
297+
f"previous one first:\n"
298+
f" import gc\n"
299+
f" del previous_gp\n"
300+
f" gc.collect()\n"
301+
f" client.run(lambda: None) # flush pending releases\n"
302+
f"Or use a fresh dask client per GP."
303+
)
304+
277305
########################################
278306
###init data instance [tier 1]##########
279307
########################################
@@ -359,6 +387,11 @@ def __init__(
359387
self.kv,
360388
self.likelihood)
361389

390+
# Register this instance for the cross-instance race-condition guard above.
391+
# Entry is removed automatically when self is garbage-collected.
392+
if gp2Scale and dask_client is not None:
393+
_GP_INSTANCES_PER_CLIENT[dask_client.id] = self
394+
362395
#########PROPERTIES#########################################
363396
@property
364397
def x_data(self):
@@ -499,7 +532,6 @@ def update_gp_data(
499532
assert isinstance(noise_variances_new, np.ndarray) or noise_variances_new is None, \
500533
"wrong format in noise_variances_new"
501534
assert len(x_new) == len(y_new), "updated x and y do not have the same lengths."
502-
old_x_data = self.x_data.copy()
503535
if rank_n_update is None: rank_n_update = append
504536
if not append and rank_n_update:
505537
warnings.warn("`rank_n_update=True` is invalid when `append=False` "
@@ -510,10 +542,8 @@ def update_gp_data(
510542
self.data.update(x_new, y_new, noise_variances_new, append=append)
511543

512544
# update prior
513-
if append:
514-
self.prior.augment_state_data(old_x_data, x_new)
515-
else:
516-
self.prior.update_state_data()
545+
if append: self.prior.augment_state_data()
546+
else:self.prior.update_state_data()
517547

518548
# update likelihood
519549
self.likelihood.update_state()

fvgp/gp_data.py

Lines changed: 27 additions & 10 deletions
@@ -37,6 +37,12 @@ def __init__(self, x_data, y_data,
3737
self.x_data = x_data
3838
self.y_data = y_data
3939
self.noise_variances = noise_variances
40+
self.x_new = None
41+
self.y_new = None
42+
self.noise_variances_new = None
43+
self.x_old = None
44+
self.y_old = None
45+
self.noise_variances_old = None
4046
self.point_number = len(self.x_data)
4147
self._check_for_nan()
4248
self.fvgp_x_data = None
@@ -48,12 +54,9 @@ def __init__(self, x_data, y_data,
4854
self.gp2Scale = gp2Scale
4955
self.compute_device = compute_device
5056
self.dask_client = dask_client
51-
self.x_data_scatter_future = None
5257
self.compute_workers = []
5358
if gp2Scale and dask_client is not None:
5459
self.compute_workers = list(dask_client.scheduler_info()["workers"].keys())
55-
self.x_data_scatter_future = dask_client.scatter(
56-
self.x_data, workers=self.compute_workers, broadcast=True, direct=True)
5760

5861
def set_fvgp_data(self, fvgp_x_data, fvgp_y_data, fvgp_noise_variances, x_out):
5962
self.fvgp_x_data = fvgp_x_data
@@ -91,19 +94,28 @@ def update(self, x_data_new, y_data_new, noise_variances_new=None, append=True):
9194
self.x_data = x_data_new
9295
self.y_data = y_data_new
9396
self.noise_variances = noise_variances_new
97+
self.x_old = None
98+
self.y_old = None
99+
self.noise_variances_old = None
100+
self.x_new = None
101+
self.y_new = None
102+
self.noise_variances_new = None
94103
else:
104+
self.x_old = self.x_data
105+
self.y_old = self.y_data
106+
self.noise_variances_old = self.noise_variances
107+
self.x_new = x_data_new
108+
self.y_new = y_data_new
109+
self.noise_variances_new = noise_variances_new
95110
if self.Euclidean: self.x_data = np.vstack([self.x_data, x_data_new])
96111
else: self.x_data = self.x_data + x_data_new
97112
self.y_data = np.vstack([self.y_data, y_data_new])
98113
if isinstance(noise_variances_new, np.ndarray):
99114
self.noise_variances = np.append(self.noise_variances, noise_variances_new)
100115
self.point_number = len(self.x_data)
101116
self._check_for_nan()
102-
if not append and self.gp2Scale and self.dask_client is not None:
103-
if self.x_data_scatter_future is not None:
104-
self.x_data_scatter_future.release()
105-
self.x_data_scatter_future = self.dask_client.scatter(
106-
self.x_data, workers=self.compute_workers, broadcast=True, direct=True)
117+
118+
107119

108120
def _check_for_nan(self):
109121
if self.Euclidean:
@@ -113,9 +125,15 @@ def __getstate__(self):
113125
state = dict(
114126
x_data=self.x_data,
115127
y_data=self.y_data,
128+
noise_variances=self.noise_variances,
116129
Euclidean=self.Euclidean,
117130
index_set_dim=self.index_set_dim,
118-
noise_variances=self.noise_variances,
131+
x_new=self.x_new,
132+
y_new=self.y_new,
133+
noise_variances_new=self.noise_variances_new,
134+
x_old=self.x_old,
135+
y_old=self.y_old,
136+
noise_variances_old=self.noise_variances_old,
119137
point_number=self.point_number,
120138
fvgp_x_data=self.fvgp_x_data,
121139
fvgp_y_data=self.fvgp_y_data,
@@ -127,7 +145,6 @@ def __getstate__(self):
127145
gp2Scale=self.gp2Scale,
128146
compute_device=self.compute_device,
129147
dask_client=None,
130-
x_data_scatter_future=None,
131148
compute_workers=self.compute_workers,
132149
)
133150
return state
