@@ -92,6 +92,12 @@ def stable_hash(self) -> bytes:
9292
9393@dataclasses .dataclass (frozen = True )
9494class DirectScalarType :
95+ """
96+ Represents a scalar value that is passed directly to the remote function.
97+
98+ For these values, BigQuery handles the serialization and deserialization without any additional processing.
99+ """
100+
95101 _py_type : type
96102
97103 @property
@@ -119,16 +125,26 @@ def stable_hash(self) -> bytes:
119125 def from_sdk_type (cls , sdk_type : bigquery .StandardSqlDataType ) -> DirectScalarType :
120126 return cls (function_typing .sdk_type_to_py_type (sdk_type ))
121127
128+ @property
129+ def emulating_type (self ) -> DirectScalarType :
130+ return self
131+
122132
123133@dataclasses .dataclass (frozen = True )
124134class VirtualListTypeV1 :
135+ """
136+ Represents a list of scalar values that is emulated as a JSON array string in the remote function.
137+
138+ Only works as an output parameter right now, where array -> string in the function runtime, and then string -> array in SQL post-processing (defined in out_expr()).
139+ """
140+
125141 _PROTOCOL_ID = "virtual_list_v1"
126142
127143 inner_dtype : DirectScalarType
128144
129145 @property
130146 def py_type (self ) -> Type [list [Any ]]:
131- return list [function_typing . sdk_type_to_py_type ( self .inner_dtype ) ] # type: ignore
147+ return list [self .inner_dtype . py_type ] # type: ignore
132148
133149 # TODO: Specify emulating type and mapping expressions between said types
134150 @property
@@ -163,6 +179,8 @@ def stable_hash(self) -> bytes:
163179class RowSeriesInputFieldV1 :
164180 """
165181 Used to handle functions that logically take a series as an input, but are handled via a string protocol in the remote function.
182+
183+ For these, the serialization is dependent on index metadata, which must be provided by the caller.
166184 """
167185
168186 _PROTOCOL_ID = "row_series_input_v1"
@@ -180,6 +198,11 @@ def bf_type(self) -> bigframes.dtypes.Dtype:
180198 def sql_type (self ) -> str :
181199 return "STRING"
182200
201+ @property
202+ def emulating_type (self ) -> DirectScalarType :
203+ # Regardless of the series contents, a string is used to emulate the row series in the remote function.
204+ return DirectScalarType (str )
205+
183206 def stable_hash (self ) -> bytes :
184207 hash_val = hashlib .md5 ()
185208 hash_val .update (self ._PROTOCOL_ID .encode ())
@@ -196,14 +219,14 @@ class UdfSignature:
196219 output : DirectScalarType | VirtualListTypeV1
197220
198221 def __post_init__ (self ):
199- if any (isinstance (arg , RowSeriesInputFieldV1 ) for arg in self .inputs ):
200- if len (self .inputs ) != 1 :
201- raise ValueError ("Row processor functions must have exactly one input." )
202222 assert all (isinstance (arg , UdfArg ) for arg in self .inputs )
203223 assert isinstance (self .output , (DirectScalarType , VirtualListTypeV1 ))
204224
205225 def to_sql_input_signature (self ) -> str :
206- return "," .join (f"{ field .name } { field .sql_type } " for field in self .inputs )
226+ return "," .join (
227+ f"{ field .name } { field .sql_type } "
228+ for field in self .with_devirtualize ().inputs
229+ )
207230
208231 @property
209232 def protocol_metadata (self ) -> str :
@@ -214,16 +237,20 @@ def protocol_metadata(self) -> str:
214237 python_output_type = self .output .py_type
215238 )
216239
240+ @property
241+ def is_virtual (self ) -> bool :
242+ dtypes = (self .output ,) + tuple (arg .dtype for arg in self .inputs )
243+ return not all (isinstance (dtype , DirectScalarType ) for dtype in dtypes )
244+
217245 @property
218246 def is_row_processor (self ) -> bool :
219- return any (isinstance (arg , RowSeriesInputFieldV1 ) for arg in self .inputs )
247+ return any (isinstance (arg . dtype , RowSeriesInputFieldV1 ) for arg in self .inputs )
220248
221249 def with_devirtualize (self ) -> UdfSignature :
222- if isinstance (self .output , DirectScalarType ):
223- return self
224- assert isinstance (self .output , VirtualListTypeV1 )
225250 return UdfSignature (
226- inputs = self .inputs ,
251+ inputs = tuple (
252+ UdfArg (arg .name , arg .dtype .emulating_type ) for arg in self .inputs
253+ ),
227254 output = self .output .emulating_type ,
228255 )
229256
@@ -245,7 +272,7 @@ def from_routine(cls, routine: bigquery.Routine) -> UdfSignature:
245272 if python_output_type := bigframes .functions ._utils .get_python_output_type_from_bigframes_metadata (
246273 routine .description
247274 ):
248- if routine . return_type is None or bq_return_type .type_kind != "STRING" :
275+ if bq_return_type .type_kind != "STRING" :
249276 raise bf_formatting .create_exception_with_feedback_link (
250277 TypeError ,
251278 "An explicit output_type should be provided only for a BigQuery function with STRING output." ,
@@ -280,7 +307,7 @@ def from_py_signature(cls, signature: inspect.Signature):
280307 ValueError ,
281308 "'input_types' was not set and parameter "
282309 f"'{ parameter .name } ' is missing a type annotation. "
283- "Types are required to use @remote_function ." ,
310+ "Types are required to use udfs ." ,
284311 )
285312
286313 input_types .append (UdfArg .from_py_param (parameter ))
@@ -290,18 +317,22 @@ def from_py_signature(cls, signature: inspect.Signature):
290317 ValueError ,
291318 "'output_type' was not set and function is missing a "
292319 "return type annotation. Types are required to use "
293- "@remote_function ." ,
320+ "udfs ." ,
294321 )
295322
296- if get_origin (signature .return_annotation ) is list :
297- inner_py_type = get_args (signature .return_annotation )[0 ]
298- virtual_list_output_type = VirtualListTypeV1 (
299- DirectScalarType (inner_py_type )
300- )
301- return cls (tuple (input_types ), virtual_list_output_type )
302- else :
303- direct_output_type = DirectScalarType (signature .return_annotation )
304- return cls (tuple (input_types ), direct_output_type )
323+ output_type = DirectScalarType (signature .return_annotation )
324+ return cls (tuple (input_types ), output_type )
325+
326+ def to_remote_function_compatible (self ) -> UdfSignature :
327+ # need to virtualize list outputs
328+ if isinstance (self .output , DirectScalarType ):
329+ if get_origin (self .output .py_type ) is list :
330+ inner_py_type = get_args (self .output .py_type )[0 ]
331+ return UdfSignature (
332+ inputs = self .inputs ,
333+ output = VirtualListTypeV1 (DirectScalarType (inner_py_type )),
334+ )
335+ return self
305336
306337 def stable_hash (self ) -> bytes :
307338 hash_val = hashlib .md5 ()
@@ -321,6 +352,8 @@ class BigqueryUdf:
321352 signature : UdfSignature
322353
323354 def with_devirtualize (self ) -> BigqueryUdf :
355+ if not self .signature .is_virtual :
356+ return self
324357 return BigqueryUdf (
325358 routine_ref = self .routine_ref ,
326359 signature = self .signature .with_devirtualize (),
0 commit comments