Skip to content

Commit 548a8b4

Browse files
committed
Validate aggregation callables accept single positional argument
1 parent deece7b commit 548a8b4

File tree

2 files changed

+26
-12
lines changed

2 files changed

+26
-12
lines changed

neat/aggregations.py

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ class InvalidAggregationFunction(TypeError):
4949
pass
5050

5151

52-
def validate_aggregation(function): # TODO: Recognize when need `reduce`
52+
def validate_aggregation(function):
5353
if not callable(function):
5454
raise InvalidAggregationFunction("A callable object is required.")
5555

@@ -60,17 +60,12 @@ def validate_aggregation(function): # TODO: Recognize when need `reduce`
6060
return
6161
raise InvalidAggregationFunction("Unable to inspect aggregation callable signature.") from exc
6262

63-
accepts_positional = any(
64-
parameter.kind in (
65-
inspect.Parameter.POSITIONAL_ONLY,
66-
inspect.Parameter.POSITIONAL_OR_KEYWORD,
67-
inspect.Parameter.VAR_POSITIONAL,
68-
)
69-
for parameter in signature.parameters.values()
70-
)
71-
72-
if not accepts_positional:
73-
raise InvalidAggregationFunction("A function taking at least one positional argument is required")
63+
try:
64+
signature.bind(object())
65+
except TypeError as exc:
66+
raise InvalidAggregationFunction(
67+
"A function taking a single positional argument is required"
68+
) from exc
7469

7570

7671
class AggregationFunctionSet:

tests/test_aggregation.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,10 @@ def keyword_only_function(*, items):
9191
return sum(items)
9292

9393

94+
def two_argument_function(items, scale):
95+
return sum(items) * scale
96+
97+
9498
def test_function_set():
9599
s = aggregations.AggregationFunctionSet()
96100
assert s.get('sum') is not None
@@ -165,6 +169,21 @@ def test_bad_add3():
165169
raise Exception("Should have had a TypeError/derived for keyword_only_function")
166170

167171

172+
def test_bad_add4():
173+
local_dir = os.path.dirname(__file__)
174+
config_path = os.path.join(local_dir, 'test_configuration')
175+
config = neat.Config(neat.DefaultGenome, neat.DefaultReproduction,
176+
neat.DefaultSpeciesSet, neat.DefaultStagnation,
177+
config_path)
178+
179+
try:
180+
config.genome_config.add_aggregation('two_argument_function', two_argument_function)
181+
except TypeError:
182+
pass
183+
else:
184+
raise Exception("Should have had a TypeError/derived for two_argument_function")
185+
186+
168187
if __name__ == '__main__':
169188
test_sum()
170189
test_product()

0 commit comments

Comments
 (0)