|
| 1 | +import logging |
1 | 2 | from contextlib import ExitStack as does_not_raise |
2 | 3 | from typing import Any |
3 | 4 | from unittest.mock import patch |
|
29 | 30 | Field, |
30 | 31 | GenDataConfig, |
31 | 32 | GenKwConfig, |
| 33 | + LocalizationType, |
32 | 34 | ObservationSettings, |
33 | 35 | OutlierSettings, |
34 | 36 | ) |
@@ -510,6 +512,103 @@ def test_that_alpha_can_be_used_for_outlier_detection( |
510 | 512 | ) |
511 | 513 |
|
512 | 514 |
|
| 515 | +@pytest.mark.parametrize( |
| 516 | + "update_strategy", |
| 517 | + [LocalizationType.ADAPTIVE, LocalizationType.GLOBAL], |
| 518 | +) |
| 519 | +def test_that_constant_parameter_is_skipped_with_warning_and_carried_over( |
| 520 | + storage, obs, caplog, update_strategy |
| 521 | +): |
| 522 | + """A parameter that is constant across all realizations has zero variance, so |
| 523 | + no parameters are selected for update. Adaptive localization used to crash in |
| 524 | + this case (calling ``assimilate_batch`` with an empty array). Regardless of the |
| 525 | + update strategy, the group should now be skipped with a warning and carried |
| 526 | + over to the posterior unchanged. |
| 527 | + """ |
| 528 | + constant_parameter = GenKwConfig( |
| 529 | + name="KEY_1", |
| 530 | + group="PARAMETER", |
| 531 | + distribution={"name": "uniform", "min": 0, "max": 1}, |
| 532 | + update_strategy=update_strategy, |
| 533 | + ).model_dump(mode="json") |
| 534 | + response_config = GenDataConfig(keys=["RESPONSE"]).model_dump(mode="json") |
| 535 | + experiment = storage.create_experiment( |
| 536 | + name="constant_param", |
| 537 | + experiment_config={ |
| 538 | + "parameter_configuration": [constant_parameter], |
| 539 | + "response_configuration": [response_config], |
| 540 | + "observations": obs, |
| 541 | + }, |
| 542 | + ) |
| 543 | + prior_storage = storage.create_ensemble( |
| 544 | + experiment, |
| 545 | + ensemble_size=10, |
| 546 | + iteration=0, |
| 547 | + name="prior", |
| 548 | + ) |
| 549 | + rng = np.random.default_rng(1234) |
| 550 | + |
| 551 | + prior_storage.save_parameters( |
| 552 | + dataset=pl.concat( |
| 553 | + [ |
| 554 | + pl.DataFrame({"KEY_1": [0.5], "realization": iens}) |
| 555 | + for iens in range(prior_storage.ensemble_size) |
| 556 | + ], |
| 557 | + how="vertical", |
| 558 | + ) |
| 559 | + ) |
| 560 | + |
| 561 | + for iens in range(prior_storage.ensemble_size): |
| 562 | + values = rng.uniform(0.8, 1, 3) |
| 563 | + prior_storage.save_response( |
| 564 | + "gen_data", |
| 565 | + pl.DataFrame( |
| 566 | + { |
| 567 | + "response_key": "RESPONSE", |
| 568 | + "report_step": pl.Series(np.full(len(values), 0), dtype=pl.UInt16), |
| 569 | + "index": pl.Series(range(len(values)), dtype=pl.UInt16), |
| 570 | + "values": values, |
| 571 | + } |
| 572 | + ), |
| 573 | + iens, |
| 574 | + ) |
| 575 | + |
| 576 | + posterior_storage = storage.create_ensemble( |
| 577 | + prior_storage.experiment_id, |
| 578 | + ensemble_size=prior_storage.ensemble_size, |
| 579 | + iteration=1, |
| 580 | + name="posterior", |
| 581 | + prior_ensemble=prior_storage, |
| 582 | + ) |
| 583 | + |
| 584 | + es_settings = ESSettings() |
| 585 | + strategy_map = build_strategy_map( |
| 586 | + parameters=["KEY_1"], |
| 587 | + param_configs=prior_storage.experiment.parameter_configuration, |
| 588 | + enkf_truncation=es_settings.enkf_truncation, |
| 589 | + correlation_threshold=es_settings.correlation_threshold, |
| 590 | + ) |
| 591 | + with caplog.at_level(logging.WARNING): |
| 592 | + smoother_update( |
| 593 | + prior_storage, |
| 594 | + posterior_storage, |
| 595 | + observations=["OBSERVATION"], |
| 596 | + update_settings=ObservationSettings(), |
| 597 | + rng=rng, |
| 598 | + strategy_map=strategy_map, |
| 599 | + ) |
| 600 | + |
| 601 | + assert "have 0 variance across realizations and will not be updated" in caplog.text |
| 602 | + |
| 603 | + prior_values = prior_storage.load_parameters_numpy( |
| 604 | + "KEY_1", np.arange(prior_storage.ensemble_size) |
| 605 | + ) |
| 606 | + posterior_values = posterior_storage.load_parameters_numpy( |
| 607 | + "KEY_1", np.arange(prior_storage.ensemble_size) |
| 608 | + ) |
| 609 | + assert np.array_equal(prior_values, posterior_values) |
| 610 | + |
| 611 | + |
513 | 612 | @pytest.mark.slow |
514 | 613 | def test_update_only_using_subset_observations( |
515 | 614 | snake_oil_case_storage, snake_oil_storage, snapshot |
|
0 commit comments