@@ -88,12 +88,8 @@ def __call__(
8888 return MyTableProvider (* args )
8989
9090
91- def test_python_table_function ():
92- ctx = SessionContext ()
93- table_func = PythonTableFunction ()
94- table_udtf = udtf (table_func , "my_table_func" )
95- ctx .register_udtf (table_udtf )
96- result = ctx .sql ("select * from my_table_func(3,2,4)" ).collect ()
91+ def common_table_function_test (test_ctx : SessionContext ) -> None :
92+ result = test_ctx .sql ("select * from my_table_func(3,2,4)" ).collect ()
9793
9894 assert len (result ) == 4
9995 assert result [0 ].num_columns == 3
@@ -108,3 +104,31 @@ def test_python_table_function():
108104 ]
109105
110106 assert result == expected
107+
108+
109+ def test_python_table_function ():
110+ ctx = SessionContext ()
111+ table_func = PythonTableFunction ()
112+ table_udtf = udtf (table_func , "my_table_func" )
113+ ctx .register_udtf (table_udtf )
114+
115+ common_table_function_test (ctx )
116+
117+
118+ def test_python_table_function_decorator ():
119+ ctx = SessionContext ()
120+
121+ @udtf ("my_table_func" )
122+ def my_udtf (
123+ num_cols : Expr , num_rows : Expr , num_batches : Expr
124+ ) -> TableProviderExportable :
125+ args = [
126+ num_cols .to_variant ().value_i64 (),
127+ num_rows .to_variant ().value_i64 (),
128+ num_batches .to_variant ().value_i64 (),
129+ ]
130+ return MyTableProvider (* args )
131+
132+ ctx .register_udtf (my_udtf )
133+
134+ common_table_function_test (ctx )
0 commit comments