Skip to content

Commit c05f06a

Browse files
yaugenst-flexmahlau-flex
authored andcommitted
fix: stabilize model hashing in tests
1 parent 9448637 commit c05f06a

2 files changed

Lines changed: 52 additions & 43 deletions

File tree

tests/test_web/test_tidy3d_stub.py

Lines changed: 44 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -118,53 +118,57 @@ def test_stub_data_to_file(tmp_path):
118118
def test_stub_data_postprocess_logs(tmp_path):
119119
"""Tests the postprocess method of Tidy3dStubData when simulation diverged."""
120120
td.log.set_capture(True)
121-
122-
# test diverged
123-
sim_data = make_sim_data()
124-
sim_data = sim_data.updated_copy(diverged=True, log="The simulation has diverged!")
125-
file_path = os.path.join(tmp_path, "test_diverged.hdf5")
126-
sim_data.to_file(file_path)
127-
Tidy3dStubData.postprocess(file_path)
128-
129-
# test warnings
130-
sim_data = make_sim_data()
131-
sim_data = sim_data.updated_copy(log="WARNING: messages were found in the solver log.")
132-
file_path = os.path.join(tmp_path, "test_warnings.hdf5")
133-
sim_data.to_file(file_path)
134-
Tidy3dStubData.postprocess(file_path)
121+
try:
122+
# test diverged
123+
sim_data = make_sim_data()
124+
sim_data = sim_data.updated_copy(diverged=True, log="The simulation has diverged!")
125+
file_path = os.path.join(tmp_path, "test_diverged.hdf5")
126+
sim_data.to_file(file_path)
127+
Tidy3dStubData.postprocess(file_path)
128+
129+
# test warnings
130+
sim_data = make_sim_data()
131+
sim_data = sim_data.updated_copy(log="WARNING: messages were found in the solver log.")
132+
file_path = os.path.join(tmp_path, "test_warnings.hdf5")
133+
sim_data.to_file(file_path)
134+
Tidy3dStubData.postprocess(file_path)
135+
finally:
136+
td.log.set_capture(False)
135137

136138

137139
@responses.activate
138140
def test_stub_data_lazy_loading(tmp_path):
139141
"""Tests the postprocess method with lazy loading of Tidy3dStubData when simulation diverged."""
140142
td.log.set_capture(True)
141143
sim_diverged_log = "The simulation has diverged!"
142-
143-
# make sim data where test diverged
144-
sim_data = make_sim_data()
145-
sim_data = sim_data.updated_copy(diverged=True, log=sim_diverged_log)
146-
file_path = os.path.join(tmp_path, "test_diverged.hdf5")
147-
sim_data.to_file(file_path)
148-
149-
# default case with lazy=False should output a warning
150-
with AssertLogLevel("WARNING", contains_str=sim_diverged_log):
151-
Tidy3dStubData.postprocess(file_path, lazy=False)
152-
153-
# we expect no warning in lazy mode as object should not be loaded
154-
with AssertLogLevel(None):
155-
sim_data = Tidy3dStubData.postprocess(file_path, lazy=True)
156-
157-
sim_data_copy = sim_data.copy()
158-
159-
# variable dict should only contain metadata to load the data, not the data itself
160-
assert is_lazy_object(sim_data)
161-
162-
# the type should be still SimulationData despite being lazy
163-
assert isinstance(sim_data, SimulationData)
164-
165-
# we expect a warning from the lazy object if some field is accessed
166-
with AssertLogLevel("WARNING", contains_str=sim_diverged_log):
167-
_ = sim_data_copy.monitor_data
144+
try:
145+
# make sim data where test diverged
146+
sim_data = make_sim_data()
147+
sim_data = sim_data.updated_copy(diverged=True, log=sim_diverged_log)
148+
file_path = os.path.join(tmp_path, "test_diverged.hdf5")
149+
sim_data.to_file(file_path)
150+
151+
# default case with lazy=False should output a warning
152+
with AssertLogLevel("WARNING", contains_str=sim_diverged_log):
153+
Tidy3dStubData.postprocess(file_path, lazy=False)
154+
155+
# we expect no warning in lazy mode as object should not be loaded
156+
with AssertLogLevel(None):
157+
sim_data = Tidy3dStubData.postprocess(file_path, lazy=True)
158+
159+
sim_data_copy = sim_data.copy()
160+
161+
# variable dict should only contain metadata to load the data, not the data itself
162+
assert is_lazy_object(sim_data)
163+
164+
# the type should be still SimulationData despite being lazy
165+
assert isinstance(sim_data, SimulationData)
166+
167+
# we expect a warning from the lazy object if some field is accessed
168+
with AssertLogLevel("WARNING", contains_str=sim_diverged_log):
169+
_ = sim_data_copy.monitor_data
170+
finally:
171+
td.log.set_capture(False)
168172

169173

170174
@pytest.mark.parametrize(

tidy3d/components/base.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
from autograd.numpy.numpy_boxes import ArrayBox
2727
from autograd.tracer import isbox
2828
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, field_validator, model_validator
29-
from pydantic.functional_validators import ModelWrapValidatorHandler
3029

3130
from tidy3d.exceptions import FileError
3231
from tidy3d.log import log
@@ -42,6 +41,7 @@
4241
from typing import Callable
4342

4443
from pydantic.fields import FieldInfo
44+
from pydantic.functional_validators import ModelWrapValidatorHandler
4545

4646
from tidy3d.compat import Self
4747

@@ -306,11 +306,16 @@ def _recursive_hash(value: Any) -> int:
306306
if isinstance(value, Tidy3dBaseModel):
307307
# This function needs to take special care because of mutable attributes inside of frozen pydantic models
308308
to_hash_list = []
309-
for k, v in dict(value).items():
309+
for k in type(value).model_fields:
310310
if k == "attrs":
311311
continue
312-
v_hash = Tidy3dBaseModel._recursive_hash(v)
312+
v_hash = Tidy3dBaseModel._recursive_hash(getattr(value, k))
313313
to_hash_list.append((k, v_hash))
314+
extra = getattr(value, "__pydantic_extra__", None)
315+
if extra:
316+
for k, v in extra.items():
317+
v_hash = Tidy3dBaseModel._recursive_hash(v)
318+
to_hash_list.append((k, v_hash))
314319
# attrs is mutable, use serialized output as safe hashing option
315320
if value.attrs:
316321
attrs_str = value._attrs_digest()

0 commit comments

Comments
 (0)