Skip to content

Commit 7c42b2d

Browse files
tests + make sure reaction products is a list
1 parent 5390a10 commit 7c42b2d

2 files changed

Lines changed: 110 additions & 1 deletion

File tree

src/festim/exports/vtx.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -385,11 +385,13 @@ def expression(T, **kwargs):
385385

386386
self.override_signature(expression, reactant_names, product_names)
387387

388+
reaction_products = reaction.product if isinstance(reaction.product, list) else [reaction.product]
389+
388390
super().__init__(
389391
filename=filename,
390392
expression=expression,
391393
species_dependent_value={
392-
spe.name: spe for spe in reaction.reactant + reaction.product
394+
spe.name: spe for spe in reaction.reactant + reaction_products
393395
},
394396
times=times,
395397
subdomain=subdomain,

test/test_vtx.py

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -402,3 +402,110 @@ def test_custom_field_not_implemented_error(expression):
402402

403403
with pytest.raises(NotImplementedError):
404404
my_model.initialise()
405+
406+
407+
@pytest.mark.parametrize("direction", ["both", "forward", "backward"])
408+
@pytest.mark.parametrize("product_type", ["list", "single"])
409+
@pytest.mark.parametrize("p_0, E_p", [(0.01, 0.05), (0.01, 0.0), (0.0, 0.0)])
410+
def test_reaction_rate_export(tmp_path, direction, product_type, p_0, E_p):
411+
"""
412+
Test ReactionRate export functionality for different directions, product formats,
413+
and reaction configurations.
414+
"""
415+
if p_0 == 0.0 and direction == "backward":
416+
pytest.skip(
417+
"Backward direction export not supported when backward reaction is disabled"
418+
)
419+
my_model = F.HydrogenTransportProblem()
420+
mat = F.Material(D_0=1, E_D=0, K_S_0=1, E_K_S=0)
421+
vol = F.VolumeSubdomain(id=1, material=mat)
422+
top = F.SurfaceSubdomain(id=1, locator=lambda x: np.isclose(x[1], 1))
423+
bottom = F.SurfaceSubdomain(id=2, locator=lambda x: np.isclose(x[1], 0))
424+
left = F.SurfaceSubdomain(id=3, locator=lambda x: np.isclose(x[0], 0))
425+
right = F.SurfaceSubdomain(id=4, locator=lambda x: np.isclose(x[0], 1))
426+
427+
my_model.subdomains = [vol, top, bottom, left, right]
428+
429+
dolfinx_mesh = dolfinx.mesh.create_unit_square(MPI.COMM_WORLD, 10, 10)
430+
my_model.mesh = F.Mesh(dolfinx_mesh)
431+
432+
A = F.Species("A")
433+
B = F.Species("B")
434+
C = F.Species("C")
435+
436+
my_model.species = [A, B, C]
437+
438+
my_model.boundary_conditions = [
439+
F.FixedConcentrationBC(species=A, subdomain=top, value=1),
440+
F.FixedConcentrationBC(species=B, subdomain=left, value=1),
441+
F.FixedConcentrationBC(species=C, subdomain=bottom, value=0),
442+
]
443+
444+
reaction = F.Reaction(
445+
reactant=[A, B],
446+
product=[C] if product_type == "list" else C,
447+
k_0=1,
448+
E_k=0.1,
449+
p_0=p_0,
450+
E_p=E_p,
451+
volume=vol,
452+
)
453+
454+
my_model.reactions = [reaction]
455+
456+
my_model.temperature = 300
457+
458+
my_model.settings = F.Settings(transient=False, atol=1e-9, rtol=1e-9)
459+
460+
reaction_rate_export = F.ReactionRate(
461+
filename=tmp_path / f"reaction_rate_{direction}.bp",
462+
reaction=reaction,
463+
direction=direction,
464+
)
465+
466+
my_model.exports = [reaction_rate_export]
467+
468+
my_model.initialise()
469+
my_model.run()
470+
471+
472+
def test_reaction_rate_override_signature():
473+
"""
474+
Test that ReactionRate signature override correctly updates signatures.
475+
"""
476+
mat = F.Material(D_0=1, E_D=0)
477+
vol = F.VolumeSubdomain(id=1, material=mat)
478+
A = F.Species("A")
479+
B = F.Species("B")
480+
reaction = F.Reaction(
481+
reactant=[A], product=[B], k_0=1, E_k=0, p_0=0, E_p=0, volume=vol
482+
)
483+
484+
rr = F.ReactionRate(reaction=reaction, filename="dummy.bp")
485+
486+
def my_expression(x, y):
487+
return x + y
488+
489+
rr.override_signature(my_expression, ["A"], ["B"])
490+
import inspect
491+
492+
sig = inspect.signature(my_expression)
493+
assert set(sig.parameters.keys()) == {"T", "A", "B"}
494+
495+
496+
def test_export_base_class_times_and_extension(tmp_path):
497+
"""
498+
Test that ExportBaseClass sorts times and warns when wrong extension is given.
499+
"""
500+
with pytest.warns(UserWarning, match="does not have .bp extension"):
501+
export = F.ExportBaseClass(
502+
filename=tmp_path / "wrong_extension.txt", ext=".bp", times=[3.0, 1.0, 2.0]
503+
)
504+
505+
assert export.filename.suffix == ".bp"
506+
assert export.times == [1.0, 2.0, 3.0]
507+
508+
509+
def test_export_base_class_no_times(tmp_path):
510+
export = F.ExportBaseClass(filename=tmp_path / "correct.bp", ext=".bp", times=None)
511+
assert export.times is None

0 commit comments

Comments
 (0)