Skip to content

Commit 04a7371

Browse files
Validate aggregation callables with single-argument bind check (#290)
Validate aggregation callables with single-argument bind check
2 parents b4e5aae + 5b53b55 commit 04a7371

File tree

2 files changed

+69
-9
lines changed

2 files changed

+69
-9
lines changed

neat/aggregations.py

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
and code for adding new user-defined ones.
44
"""
55

6+
import inspect
67
import types
78
import warnings
89
from functools import reduce
@@ -48,15 +49,25 @@ class InvalidAggregationFunction(TypeError):
4849
pass
4950

5051

51-
def validate_aggregation(function): # TODO: Recognize when need `reduce`
52-
if not isinstance(function,
53-
(types.BuiltinFunctionType,
54-
types.FunctionType,
55-
types.LambdaType)):
56-
raise InvalidAggregationFunction("A function object is required.")
57-
58-
if not (function.__code__.co_argcount >= 1):
59-
raise InvalidAggregationFunction("A function taking at least one argument is required")
52+
def validate_aggregation(function):
53+
if not callable(function):
54+
raise InvalidAggregationFunction("A callable object is required.")
55+
56+
try:
57+
signature = inspect.signature(function)
58+
except (TypeError, ValueError) as exc:
59+
# CPython builtins (e.g. max, sum) often lack introspectable signatures.
60+
# Skip signature validation for these; they are assumed correct.
61+
if isinstance(function, types.BuiltinFunctionType):
62+
return
63+
raise InvalidAggregationFunction("Unable to inspect aggregation callable signature.") from exc
64+
65+
try:
66+
signature.bind(object())
67+
except TypeError as exc:
68+
raise InvalidAggregationFunction(
69+
"A callable with exactly one required positional argument is required"
70+
) from exc
6071

6172

6273
class AggregationFunctionSet:

tests/test_aggregation.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,10 +72,29 @@ def test_add_minabs():
7272
assert config.genome_config.aggregation_function_defs.is_valid('minabs')
7373

7474

75+
def test_add_builtin_max():
76+
local_dir = os.path.dirname(__file__)
77+
config_path = os.path.join(local_dir, 'test_configuration')
78+
config = neat.Config(neat.DefaultGenome, neat.DefaultReproduction,
79+
neat.DefaultSpeciesSet, neat.DefaultStagnation,
80+
config_path)
81+
82+
config.genome_config.add_aggregation('builtin_max', max)
83+
assert config.genome_config.aggregation_function_defs.get('builtin_max') is max
84+
85+
7586
def dud_function():
7687
return 0.0
7788

7889

90+
def keyword_only_function(*, items):
91+
return sum(items)
92+
93+
94+
def two_argument_function(items, scale):
95+
return sum(items) * scale
96+
97+
7998
def test_function_set():
8099
s = aggregations.AggregationFunctionSet()
81100
assert s.get('sum') is not None
@@ -135,6 +154,36 @@ def test_bad_add2():
135154
raise Exception("Should have had a TypeError/derived for dud_function")
136155

137156

157+
def test_bad_add3():
158+
local_dir = os.path.dirname(__file__)
159+
config_path = os.path.join(local_dir, 'test_configuration')
160+
config = neat.Config(neat.DefaultGenome, neat.DefaultReproduction,
161+
neat.DefaultSpeciesSet, neat.DefaultStagnation,
162+
config_path)
163+
164+
try:
165+
config.genome_config.add_aggregation('keyword_only_function', keyword_only_function)
166+
except TypeError:
167+
pass
168+
else:
169+
raise Exception("Should have had a TypeError/derived for keyword_only_function")
170+
171+
172+
def test_bad_add4():
173+
local_dir = os.path.dirname(__file__)
174+
config_path = os.path.join(local_dir, 'test_configuration')
175+
config = neat.Config(neat.DefaultGenome, neat.DefaultReproduction,
176+
neat.DefaultSpeciesSet, neat.DefaultStagnation,
177+
config_path)
178+
179+
try:
180+
config.genome_config.add_aggregation('two_argument_function', two_argument_function)
181+
except TypeError:
182+
pass
183+
else:
184+
raise Exception("Should have had a TypeError/derived for two_argument_function")
185+
186+
138187
if __name__ == '__main__':
139188
test_sum()
140189
test_product()

0 commit comments

Comments
 (0)