Skip to content

Commit 28d1022

Browse files
mj023hmgaudeckerclaudepre-commit-ci[bot]
authored
Reorder dimensions in productmap (#335)
Co-authored-by: Hans-Martin von Gaudecker <hmgaudecker@gmail.com> Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 88ba7c2 commit 28d1022

6 files changed

Lines changed: 91 additions & 84 deletions

File tree

src/lcm/exceptions.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,10 @@ class GridInitializationError(PyLCMError):
5050
"""Raised when there is an error in the grid initialization."""
5151

5252

53+
class FunctionDispatchError(PyLCMError):
54+
"""Raised when there is an error during the function dispatch."""
55+
56+
5357
def format_messages(errors: str | list[str]) -> str:
5458
"""Convert message or list of messages into a single string."""
5559
if isinstance(errors, str):

src/lcm/regime_building/variable_info.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import math
12
from types import MappingProxyType
23

34
import pandas as pd
@@ -34,15 +35,32 @@ def get_variable_info(regime: Regime) -> pd.DataFrame:
3435
]
3536
info["is_discrete"] = ~info["is_continuous"]
3637

37-
order = info.query("is_discrete & is_state").index.tolist()
38-
order += info.query("is_discrete & is_action").index.tolist()
39-
order += info.query("is_continuous & is_state").index.tolist()
40-
order += info.query("is_continuous & is_action").index.tolist()
38+
ordered_discrete_states = sorted(
39+
info.query("is_discrete & is_state").index.tolist(),
40+
key=lambda x: (
41+
regime.states[x].batch_size
42+
if regime.states[x].batch_size != 0
43+
else math.inf
44+
),
45+
)
46+
ordered_continuous_states = sorted(
47+
info.query("is_continuous & is_state").index.tolist(),
48+
key=lambda x: (
49+
regime.states[x].batch_size
50+
if regime.states[x].batch_size != 0
51+
else math.inf
52+
),
53+
)
54+
ordered_states_and_actions = [
55+
*ordered_discrete_states,
56+
*ordered_continuous_states,
57+
*info.query("is_action").index.tolist(),
58+
]
4159

42-
if set(order) != set(info.index):
60+
if set(ordered_states_and_actions) != set(info.index):
4361
raise ValueError("Order and index do not match.")
4462

45-
return info.loc[order]
63+
return info.loc[ordered_states_and_actions]
4664

4765

4866
def get_grids(

src/lcm/utils/dispatchers.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import jax.numpy as jnp
99
from jax import Array, vmap
1010

11+
from lcm.exceptions import FunctionDispatchError
1112
from lcm.typing import Float1D, FloatND
1213
from lcm.utils.containers import find_duplicates
1314
from lcm.utils.functools import allow_args, allow_only_kwargs
@@ -224,7 +225,7 @@ def _base_productmap_batched(
224225
Like `jax.lax.map`, this function does not preserve the function signature.
225226
226227
Args:
227-
func: The function to be dispatched. Cannot have keyword-only arguments.
228+
func: The function to be dispatched. Cannot have positional-only parameters.
228229
product_axes: Tuple with names of arguments over which we apply
229230
`jax.lax.map`.
230231
batch_sizes: Dict with the batch sizes for each product_axis.
@@ -234,6 +235,13 @@ def _base_productmap_batched(
234235
235236
"""
236237
parameters = inspect.signature(func).parameters
238+
for name, param in parameters.items():
239+
if param.kind == inspect.Parameter.POSITIONAL_ONLY:
240+
raise FunctionDispatchError(
241+
"Positional-only parameters are not allowed in dispatched functions. "
242+
f"The parameter '{name}' to the function {func.__name__} "
243+
"is POSITIONAL_ONLY."
244+
)
237245

238246
def batched_vmap(**kwargs: FloatND) -> FloatND:
239247
non_array_kwargs = {
@@ -250,14 +258,6 @@ def map_one_more(
250258
def func_mapped_over_one_more_axis(
251259
*already_mapped_args: Float1D, **already_mapped_kwargs: Float1D
252260
) -> FloatND:
253-
if parameters[axis].kind == inspect.Parameter.POSITIONAL_ONLY:
254-
return jax.lax.map(
255-
lambda axis_i: loop_func(
256-
axis_i, *already_mapped_args, **already_mapped_kwargs
257-
),
258-
jnp.atleast_1d(kwargs[axis]),
259-
batch_size=batch_sizes[axis],
260-
)
261261
return jax.lax.map(
262262
lambda axis_i: loop_func(
263263
*already_mapped_args, **{axis: axis_i}, **already_mapped_kwargs
@@ -271,6 +271,7 @@ def func_mapped_over_one_more_axis(
271271
# Loop over all product axes
272272
for axis in reversed(product_axes):
273273
func_with_partialled_args = map_one_more(func_with_partialled_args, axis)
274+
274275
return cast("FloatND", func_with_partialled_args())
275276

276277
return cast("FunctionWithArrayReturn", batched_vmap)

src/lcm_examples/mahler_yum_2024/_model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -326,7 +326,7 @@ def dead_is_active(age: int, initial_age: float) -> bool:
326326
"education": DiscreteGrid(Education),
327327
"productivity": DiscreteGrid(ProductivityType),
328328
"health_type": DiscreteGrid(HealthType),
329-
"discount_type": DiscreteGrid(DiscountType, batch_size=1),
329+
"discount_type": DiscreteGrid(DiscountType),
330330
},
331331
state_transitions={
332332
"wealth": next_wealth,
@@ -380,7 +380,7 @@ def dead_utility(discount_type: DiscreteState) -> FloatND: # noqa: ARG001
380380
transition=None,
381381
active=partial(dead_is_active, initial_age=ages.values[0]),
382382
states={
383-
"discount_type": DiscreteGrid(DiscountType, batch_size=1),
383+
"discount_type": DiscreteGrid(DiscountType),
384384
},
385385
functions={"utility": dead_utility},
386386
)

tests/regime_building/test_regime_processing.py

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@
77
from numpy.testing import assert_array_equal
88
from pandas.testing import assert_frame_equal
99

10-
from lcm import DiscreteGrid
1110
from lcm.ages import AgeGrid
11+
from lcm.grids import DiscreteGrid, LinSpacedGrid
1212
from lcm.regime_building.processing import (
1313
_rename_params_to_qnames,
1414
process_regimes,
@@ -76,6 +76,35 @@ def next_c(a, b):
7676
assert got["c"].codes == (0, 1)
7777

7878

79+
def test_get_grids_reorder(binary_category_class):
80+
def next_state(a, b):
81+
pass
82+
83+
regime_mock = RegimeMock(
84+
actions={
85+
"a": DiscreteGrid(binary_category_class),
86+
},
87+
states={
88+
"b": DiscreteGrid(binary_category_class),
89+
"c": DiscreteGrid(binary_category_class, batch_size=1),
90+
"d": LinSpacedGrid(start=0, stop=1, n_points=5, batch_size=3),
91+
"e": LinSpacedGrid(start=0, stop=1, n_points=5, batch_size=1),
92+
"f": LinSpacedGrid(start=0, stop=1, n_points=5),
93+
},
94+
state_transitions={
95+
"b": next_state,
96+
"c": next_state,
97+
"d": next_state,
98+
"e": next_state,
99+
"f": next_state,
100+
},
101+
functions={"utility": lambda _c: None},
102+
)
103+
104+
got = get_grids(regime_mock) # ty: ignore[invalid-argument-type]
105+
assert list(got.keys()) == ["c", "b", "e", "d", "f", "a"]
106+
107+
79108
def test_process_regimes():
80109
ages = AgeGrid(start=0, stop=4, step="Y")
81110
regimes = {"working_life": working_life, "dead": dead}
@@ -93,12 +122,12 @@ def test_process_regimes():
93122
# Variable Info
94123
assert (
95124
internal_working_regime.variable_info["is_state"].to_numpy()
96-
== np.array([False, True, False])
125+
== np.array([True, False, False])
97126
).all()
98127

99128
assert (
100129
internal_working_regime.variable_info["is_continuous"].to_numpy()
101-
== np.array([False, True, True])
130+
== np.array([True, False, True])
102131
).all()
103132

104133
# Grids — compare the grid objects (which now include transition attributes)

tests/test_dispatchers.py

Lines changed: 19 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import pytest
55
from numpy.testing import assert_array_almost_equal as aaae
66

7+
from lcm.exceptions import FunctionDispatchError
78
from lcm.utils.dispatchers import (
89
productmap,
910
simulation_spacemap,
@@ -12,30 +13,14 @@
1213
from lcm.utils.functools import allow_args
1314

1415

15-
def f(a, /, *, b, c):
16-
"""Tests that dispatchers can handle positional-only and keyword-only arguments.
16+
def f(a, *, b, c):
17+
"""Tests that dispatchers can handle standard arguments and keyword-only arguments.
1718
18-
a is positional-only, b and c are keyword-only
19+
a is positional-or-keyword, b and c are keyword-only
1920
"""
2021
return jnp.sin(a) + jnp.cos(b) + jnp.tan(c)
2122

2223

23-
def f2(b, a, /, *, c):
24-
"""Tests that dispatchers can handle positional-only and keyword-only arguments.
25-
26-
b and a are positional-only, c is keyword-only
27-
"""
28-
return jnp.sin(a) + jnp.cos(b) + jnp.tan(c)
29-
30-
31-
def g(a, /, b, *, c, d):
32-
"""Tests that dispatchers can handle positional-only and keyword-only arguments.
33-
34-
a is positional-only, b is positional-or-keyword, c and d are keyword-only
35-
"""
36-
return f(a, b=b, c=c) + jnp.log(d)
37-
38-
3924
@pytest.fixture
4025
def setup_productmap_f():
4126
return {
@@ -57,34 +42,10 @@ def expected_productmap_f():
5742
return allow_args(f)(*helper).reshape(10, 7, 5)
5843

5944

60-
@pytest.fixture
61-
def setup_productmap_g():
62-
return {
63-
"a": jnp.linspace(-5, 5, 10),
64-
"b": jnp.linspace(0, 3, 7),
65-
"c": jnp.linspace(1, 5, 5),
66-
"d": jnp.linspace(1, 3, 4),
67-
}
68-
69-
70-
@pytest.fixture
71-
def expected_productmap_g():
72-
grids = {
73-
"a": jnp.linspace(-5, 5, 10),
74-
"b": jnp.linspace(0, 3, 7),
75-
"c": jnp.linspace(1, 5, 5),
76-
"d": jnp.linspace(1, 3, 4),
77-
}
78-
79-
helper = jnp.array(list(itertools.product(*grids.values()))).T
80-
return allow_args(g)(*helper).reshape(10, 7, 5, 4)
81-
82-
8345
@pytest.mark.parametrize(
8446
("func", "args", "grids", "expected"),
8547
[
8648
(f, ["a", "b", "c"], "setup_productmap_f", "expected_productmap_f"),
87-
(g, ["a", "b", "c", "d"], "setup_productmap_g", "expected_productmap_g"),
8849
],
8950
)
9051
def test_productmap_with_all_arguments_mapped(func, args, grids, expected, request):
@@ -112,24 +73,13 @@ def test_productmap_with_positional_args(setup_productmap_f):
11273
decorated(*setup_productmap_f.values()) # ty: ignore[missing-argument]
11374

11475

115-
def test_productmap_different_func_order(setup_productmap_f):
116-
_bs = dict.fromkeys(("a", "b", "c"), 0)
117-
decorated_f = productmap(func=f, variables=("a", "b", "c"), batch_sizes=_bs)
118-
expected = decorated_f(**setup_productmap_f) # ty: ignore[missing-argument]
119-
120-
decorated_f2 = productmap(func=f2, variables=("a", "b", "c"), batch_sizes=_bs)
121-
calculated_f2 = decorated_f2(**setup_productmap_f) # ty: ignore[missing-argument]
122-
123-
aaae(calculated_f2, expected)
124-
125-
12676
def test_productmap_change_arg_order(setup_productmap_f, expected_productmap_f):
12777
expected = jnp.transpose(expected_productmap_f, (1, 0, 2))
12878

12979
decorated = productmap(
13080
func=f, variables=("b", "a", "c"), batch_sizes=dict.fromkeys(("b", "a", "c"), 0)
13181
)
132-
calculated = decorated(**setup_productmap_f) # ty: ignore[missing-argument]
82+
calculated = decorated(**setup_productmap_f)
13383

13484
aaae(calculated, expected)
13585

@@ -148,7 +98,7 @@ def test_productmap_with_all_arguments_mapped_some_len_one():
14898
decorated = productmap(
14999
func=f, variables=("a", "b", "c"), batch_sizes=dict.fromkeys(("a", "b", "c"), 0)
150100
)
151-
calculated = decorated(**grids) # ty: ignore[missing-argument]
101+
calculated = decorated(**grids)
152102
aaae(calculated, expected)
153103

154104

@@ -166,7 +116,7 @@ def test_productmap_with_some_arguments_mapped():
166116
decorated = productmap(
167117
func=f, variables=("a", "c"), batch_sizes=dict.fromkeys(("a", "c"), 0)
168118
)
169-
calculated = decorated(**grids) # ty: ignore[missing-argument]
119+
calculated = decorated(**grids)
170120
aaae(calculated, expected)
171121

172122

@@ -201,6 +151,14 @@ def test_productmap_with_some_argument_mapped_twice():
201151
)
202152

203153

154+
def test_productmap_rejects_positional_only():
155+
def h(a, /, *, b):
156+
return a + b
157+
158+
with pytest.raises(FunctionDispatchError, match="POSITIONAL_ONLY"):
159+
productmap(func=h, variables=("a", "b"), batch_sizes={"a": 0, "b": 0})
160+
161+
204162
@pytest.fixture
205163
def setup_spacemap():
206164
value_grid = {
@@ -210,14 +168,12 @@ def setup_spacemap():
210168

211169
combination_values = {
212170
"c": jnp.array([7.0, 8, 9, 10]),
213-
"d": jnp.array([9.0, 10, 11, 12, 13]),
214171
}
215172

216173
helper = jnp.array(list(itertools.product(*combination_values.values()))).T
217174

218175
combination_grid = {
219176
"c": helper[0],
220-
"d": helper[1],
221177
}
222178
return value_grid, combination_grid
223179

@@ -231,13 +187,12 @@ def expected_spacemap():
231187

232188
combination_grid = {
233189
"c": jnp.array([7.0, 8, 9, 10]),
234-
"d": jnp.array([9.0, 10, 11, 12, 13]),
235190
}
236191

237192
all_grids = {**value_grid, **combination_grid}
238193
helper = jnp.array(list(itertools.product(*all_grids.values()))).T
239194

240-
return allow_args(g)(*helper).reshape(3, 2, 4 * 5)
195+
return allow_args(f)(*helper).reshape(3, 2, 4)
241196

242197

243198
def test_spacemap_all_arguments_mapped(
@@ -247,11 +202,11 @@ def test_spacemap_all_arguments_mapped(
247202
product_vars, combination_vars = setup_spacemap
248203

249204
decorated = simulation_spacemap(
250-
func=g,
205+
func=f,
251206
action_names=tuple(product_vars),
252207
state_names=tuple(combination_vars),
253208
)
254-
calculated = decorated(**product_vars, **combination_vars) # ty: ignore[missing-argument]
209+
calculated = decorated(**product_vars, **combination_vars)
255210

256211
aaae(calculated, jnp.transpose(expected_spacemap, axes=(2, 0, 1)))
257212

@@ -274,7 +229,7 @@ def test_spacemap_all_arguments_mapped(
274229
def test_spacemap_arguments_overlap(error_msg, product_vars, combination_vars):
275230
with pytest.raises(ValueError, match=error_msg):
276231
simulation_spacemap(
277-
func=g, action_names=product_vars, state_names=combination_vars
232+
func=f, action_names=product_vars, state_names=combination_vars
278233
)
279234

280235

0 commit comments

Comments
 (0)