Skip to content

Commit d4a95a7

Browse files
authored
adjust to new params format with optional param in decorator (#82)
Signed-off-by: Walter Martin <wamartin@microsoft.com>
1 parent 9c62fb5 commit d4a95a7

3 files changed

Lines changed: 36 additions & 41 deletions

File tree

inference_schema/schema_decorators.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from ._constants import INPUT_SCHEMA_ATTR, OUTPUT_SCHEMA_ATTR
1313

1414

15-
def input_schema(param_name, param_type, convert_to_provided_type=True):
15+
def input_schema(param_name, param_type, convert_to_provided_type=True, optional=False):
1616
"""
1717
Decorator to define an input schema model for a function parameter
1818
The input schema is a representation of what type the function expects
@@ -46,14 +46,16 @@ def decorator_input(user_run, instance, args, kwargs):
4646
if convert_to_provided_type:
4747
args = list(args)
4848

49-
if param_name not in kwargs.keys():
49+
if param_name not in kwargs.keys() and not optional:
5050
decorators = _get_decorators(user_run)
5151
arg_names = inspect.getfullargspec(decorators[-1]).args
5252
if param_name not in arg_names:
5353
raise Exception('Error, provided param_name "{}" '
5454
'is not in the decorated function.'.format(param_name))
5555
param_position = arg_names.index(param_name)
5656
args[param_position] = _deserialize_input_argument(args[param_position], param_type, param_name)
57+
elif param_name not in kwargs.keys() and optional:
58+
pass
5759
else:
5860
kwargs[param_name] = _deserialize_input_argument(kwargs[param_name], param_type, param_name)
5961

tests/conftest.py

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -186,17 +186,14 @@ def pandas_url_func(param):
186186

187187
@pytest.fixture(scope="session")
188188
def decorated_pandas_func_parameters(pandas_sample_input_for_params, sample_param_dict):
189-
@input_schema('input_data', StandardPythonParameterType({
190-
'split_df': PandasParameterType(pandas_sample_input_for_params, orient='split'),
191-
'parameters': StandardPythonParameterType(sample_param_dict)
192-
}))
193-
def pandas_params_func(input_data):
194-
assert type(input_data) is dict
195-
assert type(input_data["split_df"]) is pd.DataFrame
196-
if 'parameters' in input_data:
197-
assert type(input_data["parameters"]) is dict
198-
beams = input_data['parameters']['num_beams'] if 'parameters' in input_data else 0
199-
return input_data["split_df"]["sentence1"], beams
189+
@input_schema('input_data', PandasParameterType(pandas_sample_input_for_params, orient='split'))
190+
@input_schema('params', StandardPythonParameterType(sample_param_dict), optional=True)
191+
def pandas_params_func(input_data, params=None):
192+
assert type(input_data) is pd.DataFrame
193+
if params is not None:
194+
assert type(params) is dict
195+
beams = params['num_beams'] if params is not None else 0
196+
return input_data["sentence1"], beams
200197

201198
return pandas_params_func
202199

tests/test_pandas_parameter_type.py

Lines changed: 24 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -79,38 +79,34 @@ def test_pandas_categorical_handling(self, decorated_pandas_categorical_func):
7979
assert categorical == result
8080

8181
def test_pandas_params_handling(self, decorated_pandas_func_parameters):
82-
pandas_input_data = {"input_data": {
83-
"split_df": {
84-
"columns": [
85-
"sentence1"
86-
],
87-
"data": [
88-
["this is a string starting with"]
89-
],
90-
"index": [0]
91-
},
92-
"parameters": {
93-
"num_beams": 2,
94-
"max_length": 512
95-
}
96-
}}
97-
result = decorated_pandas_func_parameters(**pandas_input_data)
82+
pandas_input_data = {
83+
"columns": [
84+
"sentence1"
85+
],
86+
"data": [
87+
["this is a string starting with"]
88+
],
89+
"index": [0]
90+
}
91+
parameters = {
92+
"num_beams": 2,
93+
"max_length": 512
94+
}
95+
result = decorated_pandas_func_parameters(pandas_input_data, params=parameters)
9896
assert result[0][0] == "this is a string starting with"
9997
assert result[1] == 2
10098

10199
def test_pandas_params_handling_without_params(self, decorated_pandas_func_parameters):
102-
pandas_input_data = {"input_data": {
103-
"split_df": {
104-
"columns": [
105-
"sentence1"
106-
],
107-
"data": [
108-
["this is a string starting with"]
109-
],
110-
"index": [0]
111-
}
112-
}}
113-
result = decorated_pandas_func_parameters(**pandas_input_data)
100+
pandas_input_data = {
101+
"columns": [
102+
"sentence1"
103+
],
104+
"data": [
105+
["this is a string starting with"]
106+
],
107+
"index": [0]
108+
}
109+
result = decorated_pandas_func_parameters(pandas_input_data)
114110
assert result[0][0] == "this is a string starting with"
115111
assert result[1] == 0
116112

0 commit comments

Comments
 (0)