44import pytest
55from numpy .testing import assert_array_almost_equal as aaae
66
7+ from lcm .exceptions import FunctionDispatchError
78from lcm .utils .dispatchers import (
89 productmap ,
910 simulation_spacemap ,
1213from lcm .utils .functools import allow_args
1314
1415
15- def f (a , / , * , b , c ):
16- """Tests that dispatchers can handle positional-only and keyword-only arguments.
16+ def f (a , * , b , c ):
17+ """Tests that dispatchers can handle standard arguments and keyword-only arguments.
1718
18- a is positional-only , b and c are keyword-only
19+ a is positional-or-keyword , b and c are keyword-only
1920 """
2021 return jnp .sin (a ) + jnp .cos (b ) + jnp .tan (c )
2122
2223
23- def f2 (b , a , / , * , c ):
24- """Tests that dispatchers can handle positional-only and keyword-only arguments.
25-
26- b and a are positional-only, c is keyword-only
27- """
28- return jnp .sin (a ) + jnp .cos (b ) + jnp .tan (c )
29-
30-
31- def g (a , / , b , * , c , d ):
32- """Tests that dispatchers can handle positional-only and keyword-only arguments.
33-
34- a is positional-only, b is positional-or-keyword, c and d are keyword-only
35- """
36- return f (a , b = b , c = c ) + jnp .log (d )
37-
38-
3924@pytest .fixture
4025def setup_productmap_f ():
4126 return {
@@ -57,34 +42,10 @@ def expected_productmap_f():
5742 return allow_args (f )(* helper ).reshape (10 , 7 , 5 )
5843
5944
60- @pytest .fixture
61- def setup_productmap_g ():
62- return {
63- "a" : jnp .linspace (- 5 , 5 , 10 ),
64- "b" : jnp .linspace (0 , 3 , 7 ),
65- "c" : jnp .linspace (1 , 5 , 5 ),
66- "d" : jnp .linspace (1 , 3 , 4 ),
67- }
68-
69-
70- @pytest .fixture
71- def expected_productmap_g ():
72- grids = {
73- "a" : jnp .linspace (- 5 , 5 , 10 ),
74- "b" : jnp .linspace (0 , 3 , 7 ),
75- "c" : jnp .linspace (1 , 5 , 5 ),
76- "d" : jnp .linspace (1 , 3 , 4 ),
77- }
78-
79- helper = jnp .array (list (itertools .product (* grids .values ()))).T
80- return allow_args (g )(* helper ).reshape (10 , 7 , 5 , 4 )
81-
82-
8345@pytest .mark .parametrize (
8446 ("func" , "args" , "grids" , "expected" ),
8547 [
8648 (f , ["a" , "b" , "c" ], "setup_productmap_f" , "expected_productmap_f" ),
87- (g , ["a" , "b" , "c" , "d" ], "setup_productmap_g" , "expected_productmap_g" ),
8849 ],
8950)
9051def test_productmap_with_all_arguments_mapped (func , args , grids , expected , request ):
@@ -112,24 +73,13 @@ def test_productmap_with_positional_args(setup_productmap_f):
11273 decorated (* setup_productmap_f .values ()) # ty: ignore[missing-argument]
11374
11475
115- def test_productmap_different_func_order (setup_productmap_f ):
116- _bs = dict .fromkeys (("a" , "b" , "c" ), 0 )
117- decorated_f = productmap (func = f , variables = ("a" , "b" , "c" ), batch_sizes = _bs )
118- expected = decorated_f (** setup_productmap_f ) # ty: ignore[missing-argument]
119-
120- decorated_f2 = productmap (func = f2 , variables = ("a" , "b" , "c" ), batch_sizes = _bs )
121- calculated_f2 = decorated_f2 (** setup_productmap_f ) # ty: ignore[missing-argument]
122-
123- aaae (calculated_f2 , expected )
124-
125-
12676def test_productmap_change_arg_order (setup_productmap_f , expected_productmap_f ):
12777 expected = jnp .transpose (expected_productmap_f , (1 , 0 , 2 ))
12878
12979 decorated = productmap (
13080 func = f , variables = ("b" , "a" , "c" ), batch_sizes = dict .fromkeys (("b" , "a" , "c" ), 0 )
13181 )
132- calculated = decorated (** setup_productmap_f ) # ty: ignore[missing-argument]
82+ calculated = decorated (** setup_productmap_f )
13383
13484 aaae (calculated , expected )
13585
@@ -148,7 +98,7 @@ def test_productmap_with_all_arguments_mapped_some_len_one():
14898 decorated = productmap (
14999 func = f , variables = ("a" , "b" , "c" ), batch_sizes = dict .fromkeys (("a" , "b" , "c" ), 0 )
150100 )
151- calculated = decorated (** grids ) # ty: ignore[missing-argument]
101+ calculated = decorated (** grids )
152102 aaae (calculated , expected )
153103
154104
@@ -166,7 +116,7 @@ def test_productmap_with_some_arguments_mapped():
166116 decorated = productmap (
167117 func = f , variables = ("a" , "c" ), batch_sizes = dict .fromkeys (("a" , "c" ), 0 )
168118 )
169- calculated = decorated (** grids ) # ty: ignore[missing-argument]
119+ calculated = decorated (** grids )
170120 aaae (calculated , expected )
171121
172122
@@ -201,6 +151,14 @@ def test_productmap_with_some_argument_mapped_twice():
201151 )
202152
203153
154+ def test_productmap_rejects_positional_only ():
155+ def h (a , / , * , b ):
156+ return a + b
157+
158+ with pytest .raises (FunctionDispatchError , match = "POSITIONAL_ONLY" ):
159+ productmap (func = h , variables = ("a" , "b" ), batch_sizes = {"a" : 0 , "b" : 0 })
160+
161+
204162@pytest .fixture
205163def setup_spacemap ():
206164 value_grid = {
@@ -210,14 +168,12 @@ def setup_spacemap():
210168
211169 combination_values = {
212170 "c" : jnp .array ([7.0 , 8 , 9 , 10 ]),
213- "d" : jnp .array ([9.0 , 10 , 11 , 12 , 13 ]),
214171 }
215172
216173 helper = jnp .array (list (itertools .product (* combination_values .values ()))).T
217174
218175 combination_grid = {
219176 "c" : helper [0 ],
220- "d" : helper [1 ],
221177 }
222178 return value_grid , combination_grid
223179
@@ -231,13 +187,12 @@ def expected_spacemap():
231187
232188 combination_grid = {
233189 "c" : jnp .array ([7.0 , 8 , 9 , 10 ]),
234- "d" : jnp .array ([9.0 , 10 , 11 , 12 , 13 ]),
235190 }
236191
237192 all_grids = {** value_grid , ** combination_grid }
238193 helper = jnp .array (list (itertools .product (* all_grids .values ()))).T
239194
240- return allow_args (g )(* helper ).reshape (3 , 2 , 4 * 5 )
195+ return allow_args (f )(* helper ).reshape (3 , 2 , 4 )
241196
242197
243198def test_spacemap_all_arguments_mapped (
@@ -247,11 +202,11 @@ def test_spacemap_all_arguments_mapped(
247202 product_vars , combination_vars = setup_spacemap
248203
249204 decorated = simulation_spacemap (
250- func = g ,
205+ func = f ,
251206 action_names = tuple (product_vars ),
252207 state_names = tuple (combination_vars ),
253208 )
254- calculated = decorated (** product_vars , ** combination_vars ) # ty: ignore[missing-argument]
209+ calculated = decorated (** product_vars , ** combination_vars )
255210
256211 aaae (calculated , jnp .transpose (expected_spacemap , axes = (2 , 0 , 1 )))
257212
@@ -274,7 +229,7 @@ def test_spacemap_all_arguments_mapped(
274229def test_spacemap_arguments_overlap (error_msg , product_vars , combination_vars ):
275230 with pytest .raises (ValueError , match = error_msg ):
276231 simulation_spacemap (
277- func = g , action_names = product_vars , state_names = combination_vars
232+ func = f , action_names = product_vars , state_names = combination_vars
278233 )
279234
280235
0 commit comments