Skip to content

Commit 847d4a0

Browse files
committed
cond. free syms
1 parent 87b699e commit 847d4a0

File tree

4 files changed

+45
-24
lines changed

4 files changed

+45
-24
lines changed

petab/v2/conditions.py

Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,11 @@
22

33
from __future__ import annotations
44

5-
from itertools import chain
65
from pathlib import Path
76

87
import pandas as pd
9-
import sympy as sp
108

11-
from .. import v2
129
from ..v1.lint import assert_no_leading_trailing_whitespace
13-
from .C import *
1410

1511
__all__ = [
1612
"get_condition_df",
@@ -50,20 +46,3 @@ def write_condition_df(df: pd.DataFrame, filename: str | Path) -> None:
5046
"""
5147
df = get_condition_df(df)
5248
df.to_csv(filename, sep="\t", index=False)
53-
54-
55-
def get_condition_table_free_symbols(problem: v2.Problem) -> set[sp.Basic]:
56-
"""Free symbols from condition table assignments.
57-
58-
Collects all free symbols from the condition table `targetValue` column.
59-
60-
:returns: Set of free symbols.
61-
"""
62-
return set(
63-
chain.from_iterable(
64-
change.target_value.free_symbols
65-
for condition in problem.conditions_table.conditions
66-
for change in condition.changes
67-
if change.target_value is not None
68-
)
69-
)

petab/v2/core.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import re
66
from collections.abc import Sequence
77
from enum import Enum
8+
from itertools import chain
89
from pathlib import Path
910
from typing import Annotated, Literal
1011

@@ -468,6 +469,23 @@ def __iadd__(self, other: Condition) -> ConditionsTable:
468469
self.conditions.append(other)
469470
return self
470471

472+
@property
473+
def free_symbols(self) -> set[sp.Symbol]:
474+
"""Get all free symbols in the conditions table.
475+
476+
This includes all free symbols in the target values of the changes,
477+
independently of whether it is referenced by any experiment, or
478+
(indirectly) by any measurement.
479+
"""
480+
return set(
481+
chain.from_iterable(
482+
change.target_value.free_symbols
483+
for condition in self.conditions
484+
for change in condition.changes
485+
if change.target_value is not None
486+
)
487+
)
488+
471489

472490
class ExperimentPeriod(BaseModel):
473491
"""A period of a timecourse or experiment defined by a start time

petab/v2/lint.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
import pandas as pd
1414
import sympy as sp
1515

16-
from .. import v2
1716
from .problem import Problem
1817

1918
logger = logging.getLogger(__name__)
@@ -743,7 +742,7 @@ def append_overrides(overrides):
743742
append_overrides(measurement.noise_parameters)
744743

745744
# Append parameter overrides from condition table
746-
for p in v2.conditions.get_condition_table_free_symbols(problem):
745+
for p in problem.conditions_table.free_symbols:
747746
parameter_ids[str(p)] = None
748747

749748
return set(parameter_ids.keys())
@@ -822,7 +821,7 @@ def append_overrides(overrides):
822821
# model
823822
parameter_ids.update(
824823
str(p)
825-
for p in v2.conditions.get_condition_table_free_symbols(problem)
824+
for p in problem.conditions_table.free_symbols
826825
if not problem.model.has_entity_with_id(str(p))
827826
)
828827

tests/v2/test_core.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -251,3 +251,28 @@ def test_experiment():
251251

252252
with pytest.raises(ValidationError, match="Invalid ID"):
253253
Experiment(id="experiment 1")
254+
255+
256+
def test_conditions_table():
257+
assert ConditionsTable().free_symbols == set()
258+
259+
assert (
260+
ConditionsTable(
261+
conditions=[
262+
Condition(
263+
id="condition1",
264+
changes=[Change(target_id="k1", target_value="true")],
265+
)
266+
]
267+
).free_symbols
268+
== set()
269+
)
270+
271+
assert ConditionsTable(
272+
conditions=[
273+
Condition(
274+
id="condition1",
275+
changes=[Change(target_id="k1", target_value=x / y)],
276+
)
277+
]
278+
).free_symbols == {x, y}

0 commit comments

Comments
 (0)