Skip to content

Commit 36f2b13

Browse files
committed
Fix fixed length strings / binary; add tests for fixed strings / binary; test no args / no return value
1 parent 2da33ff commit 36f2b13

6 files changed

Lines changed: 132 additions & 38 deletions

File tree

accel.c

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3929,7 +3929,8 @@ static PyObject *dump_rowdat_1_numpy(PyObject *self, PyObject *args, PyObject *k
39293929
out_idx += 8;
39303930

39313931
} else if (col_types[i].type == NUMPY_FIXED_STRING) {
3932-
void *bytes = (void*)(cols[i] + j * 8);
3932+
// Jump to col_types[i].length * 4 for UCS4 fixed length string
3933+
void *bytes = (void*)(cols[i] + j * col_types[i].length * 4);
39333934

39343935
if (bytes == NULL) {
39353936
CHECKMEM(8);
@@ -3944,6 +3945,7 @@ static PyObject *dump_rowdat_1_numpy(PyObject *self, PyObject *args, PyObject *k
39443945
if (utf8_str) free(utf8_str);
39453946
goto error;
39463947
}
3948+
str_l = strnlen(utf8_str, str_l);
39473949
CHECKMEM(8+str_l);
39483950
i64 = str_l;
39493951
memcpy(out+out_idx, &i64, 8);
@@ -4010,7 +4012,7 @@ static PyObject *dump_rowdat_1_numpy(PyObject *self, PyObject *args, PyObject *k
40104012
out_idx += 8;
40114013

40124014
} else if (col_types[i].type == NUMPY_BYTES) {
4013-
void *bytes = (void*)(cols[i] + j * 8);
4015+
void *bytes = (void*)(cols[i] + j * col_types[i].length);
40144016

40154017
if (bytes == NULL) {
40164018
CHECKMEM(8);
@@ -4434,7 +4436,10 @@ static PyObject *dump_rowdat_1(PyObject *self, PyObject *args, PyObject *kwargs)
44344436

44354437
// Get return types
44364438
n_cols = (unsigned long long)PyObject_Length(py_returns);
4437-
if (n_cols == 0) goto error;
4439+
if (n_cols == 0) {
4440+
PyErr_SetString(PyExc_ValueError, "no return values specified");
4441+
goto error;
4442+
}
44384443

44394444
returns = malloc(sizeof(int) * n_cols);
44404445
if (!returns) goto error;

singlestoredb/functions/ext/asgi.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -272,7 +272,10 @@ async def do_func( # type: ignore
272272
return row_ids, [out]
273273

274274
# Call function on each column of data
275-
res = get_dataframe_columns(func(*[x[0] for x in cols]))
275+
if cols and cols[0]:
276+
res = get_dataframe_columns(func(*[x[0] for x in cols]))
277+
else:
278+
res = get_dataframe_columns(func())
276279

277280
# Generate row IDs
278281
row_ids = array_cls([row_ids[0]] * len(res[0]))
@@ -308,7 +311,10 @@ async def do_func( # type: ignore
308311
return row_ids, [out]
309312

310313
# Call the function with `cols` as the function parameters
311-
out = func(*[x[0] for x in cols])
314+
if cols and cols[0]:
315+
out = func(*[x[0] for x in cols])
316+
else:
317+
out = func()
312318

313319
# Multiple return values
314320
if isinstance(out, tuple):
@@ -717,6 +723,7 @@ async def __call__(
717723
func_info['colspec'], b''.join(data),
718724
),
719725
)
726+
print(func_info['returns'], out)
720727
body = output_handler['dump'](
721728
[x[1] for x in func_info['returns']], *out, # type: ignore
722729
)

singlestoredb/functions/signature.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1115,6 +1115,11 @@ def get_signature(
11151115
with_null_masks=with_null_masks,
11161116
)
11171117

1118+
# All functions have to return a value, so if none was specified try to
1119+
# insert a reasonable default that includes NULLs.
1120+
if not ret_schema:
1121+
ret_schema = [('', 'int8?', 'TINYINT NULL')]
1122+
11181123
# Generate names for fields as needed
11191124
if function_type == 'tvf' or len(ret_schema) > 1:
11201125
for i, (name, rtype, sql) in enumerate(ret_schema):
@@ -1300,7 +1305,9 @@ def signature_to_sql(
13001305
res = ret[0]['sql']
13011306
returns = f' RETURNS {res}'
13021307
else:
1303-
returns = ' RETURNS NULL'
1308+
raise ValueError(
1309+
'function signature must have a return type specified',
1310+
)
13041311

13051312
host = os.environ.get('SINGLESTOREDB_EXT_HOST', '127.0.0.1')
13061313
port = os.environ.get('SINGLESTOREDB_EXT_PORT', '8000')

singlestoredb/tests/ext_funcs/__init__.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,11 @@
99

1010
from singlestoredb.functions import Masked
1111
from singlestoredb.functions import MaskedNDArray
12+
from singlestoredb.functions import tvf
1213
from singlestoredb.functions import udf
1314
from singlestoredb.functions import udf_with_null_masks
1415
from singlestoredb.functions.dtypes import BIGINT
16+
from singlestoredb.functions.dtypes import BLOB
1517
from singlestoredb.functions.dtypes import DOUBLE
1618
from singlestoredb.functions.dtypes import FLOAT
1719
from singlestoredb.functions.dtypes import MEDIUMINT
@@ -480,3 +482,34 @@ def arrow_nullable_tinyint_mult_with_masks(
480482
x_data, x_nulls = x
481483
y_data, y_nulls = y
482484
return (pc.multiply(x_data, y_data), pc.or_(x_nulls, y_nulls))
485+
486+
487+
@tvf(returns=[TEXT(nullable=False, name='res')])
488+
def numpy_fixed_strings() -> npt.NDArray[np.str_]:
489+
out = np.array(
490+
[
491+
'hello',
492+
'hi there 😜',
493+
'😜 bye',
494+
], dtype=np.str_,
495+
)
496+
assert str(out.dtype) == '<U10'
497+
return out
498+
499+
500+
@tvf(returns=[BLOB(nullable=False, name='res')])
501+
def numpy_fixed_binary() -> npt.NDArray[np.bytes_]:
502+
out = np.array(
503+
[
504+
'hello'.encode('utf8'),
505+
'hi there 😜'.encode('utf8'),
506+
'😜 bye'.encode('utf8'),
507+
], dtype=np.bytes_,
508+
)
509+
assert str(out.dtype) == '|S13'
510+
return out
511+
512+
513+
@udf
514+
def no_args_no_return_value() -> None:
515+
pass

singlestoredb/tests/test_ext_func.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1193,3 +1193,44 @@ def _test_nullable_varchar_mult(self):
11931193
assert desc[0].name == 'res'
11941194
assert desc[0].type_code == ft.BLOB
11951195
assert desc[0].null_ok is True
1196+
1197+
def test_numpy_fixed_strings(self):
1198+
self.cur.execute('select * from numpy_fixed_strings()')
1199+
1200+
assert [tuple(x) for x in self.cur] == [
1201+
('hello',),
1202+
('hi there 😜',),
1203+
('😜 bye',),
1204+
]
1205+
1206+
desc = self.cur.description
1207+
assert len(desc) == 1
1208+
assert desc[0].name == 'res'
1209+
assert desc[0].type_code == ft.BLOB
1210+
assert desc[0].null_ok is False
1211+
1212+
def test_numpy_fixed_binary(self):
1213+
self.cur.execute('select * from numpy_fixed_binary()')
1214+
1215+
assert [tuple(x) for x in self.cur] == [
1216+
('hello'.encode('utf8') + b'\x00' * 8,),
1217+
('hi there 😜'.encode('utf8'),),
1218+
('😜 bye'.encode('utf8') + b'\x00' * 5,),
1219+
]
1220+
1221+
desc = self.cur.description
1222+
assert len(desc) == 1
1223+
assert desc[0].name == 'res'
1224+
assert desc[0].type_code == ft.BLOB
1225+
assert desc[0].null_ok is False
1226+
1227+
def test_no_args_no_return_value(self):
1228+
self.cur.execute('select no_args_no_return_value() as res')
1229+
1230+
assert [tuple(x) for x in self.cur] == [(None,)]
1231+
1232+
desc = self.cur.description
1233+
assert len(desc) == 1
1234+
assert desc[0].name == 'res'
1235+
assert desc[0].type_code == ft.TINY
1236+
assert desc[0].null_ok is True

singlestoredb/tests/test_udf.py

Lines changed: 33 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def foo(): ...
4848

4949
# NULL return value
5050
def foo() -> None: ...
51-
assert to_sql(foo) == '`foo`() RETURNS NULL'
51+
assert to_sql(foo) == '`foo`() RETURNS TINYINT NULL'
5252

5353
# Simple return value
5454
def foo() -> int: ...
@@ -138,44 +138,44 @@ def foo(x) -> None: ...
138138

139139
# Simple parameter
140140
def foo(x: int) -> None: ...
141-
assert to_sql(foo) == '`foo`(`x` BIGINT NOT NULL) RETURNS NULL'
141+
assert to_sql(foo) == '`foo`(`x` BIGINT NOT NULL) RETURNS TINYINT NULL'
142142

143143
# Optional parameter
144144
def foo(x: Optional[int]) -> None: ...
145-
assert to_sql(foo) == '`foo`(`x` BIGINT NULL) RETURNS NULL'
145+
assert to_sql(foo) == '`foo`(`x` BIGINT NULL) RETURNS TINYINT NULL'
146146

147147
# Optional parameter
148148
def foo(x: Union[int, None]) -> None: ...
149-
assert to_sql(foo) == '`foo`(`x` BIGINT NULL) RETURNS NULL'
149+
assert to_sql(foo) == '`foo`(`x` BIGINT NULL) RETURNS TINYINT NULL'
150150

151151
# Optional multiple parameter types
152152
def foo(x: Union[int, float, None]) -> None: ...
153-
assert to_sql(foo) == '`foo`(`x` DOUBLE NULL) RETURNS NULL'
153+
assert to_sql(foo) == '`foo`(`x` DOUBLE NULL) RETURNS TINYINT NULL'
154154

155155
# Optional parameter with custom type
156156
def foo(x: Optional[B]) -> None: ...
157-
assert to_sql(foo) == '`foo`(`x` DOUBLE NULL) RETURNS NULL'
157+
assert to_sql(foo) == '`foo`(`x` DOUBLE NULL) RETURNS TINYINT NULL'
158158

159159
# Optional parameter with nested custom type
160160
def foo(x: Optional[C]) -> None: ...
161-
assert to_sql(foo) == '`foo`(`x` DOUBLE NULL) RETURNS NULL'
161+
assert to_sql(foo) == '`foo`(`x` DOUBLE NULL) RETURNS TINYINT NULL'
162162

163163
# Optional parameter with collection type
164164
def foo(x: Optional[List[str]]) -> None: ...
165-
assert to_sql(foo) == '`foo`(`x` ARRAY(TEXT NOT NULL) NULL) RETURNS NULL'
165+
assert to_sql(foo) == '`foo`(`x` ARRAY(TEXT NOT NULL) NULL) RETURNS TINYINT NULL'
166166

167167
# Optional parameter with nested collection type
168168
def foo(x: Optional[List[List[str]]]) -> None: ...
169169
assert to_sql(foo) == '`foo`(`x` ARRAY(ARRAY(TEXT NOT NULL) NOT NULL) NULL) ' \
170-
'RETURNS NULL'
170+
'RETURNS TINYINT NULL'
171171

172172
# Optional parameter with collection type with nulls
173173
def foo(x: Optional[List[Optional[str]]]) -> None: ...
174-
assert to_sql(foo) == '`foo`(`x` ARRAY(TEXT NULL) NULL) RETURNS NULL'
174+
assert to_sql(foo) == '`foo`(`x` ARRAY(TEXT NULL) NULL) RETURNS TINYINT NULL'
175175

176176
# Custom type with bound
177177
def foo(x: D) -> None: ...
178-
assert to_sql(foo) == '`foo`(`x` TEXT NOT NULL) RETURNS NULL'
178+
assert to_sql(foo) == '`foo`(`x` TEXT NOT NULL) RETURNS TINYINT NULL'
179179

180180
# Incompatible types
181181
def foo(x: Union[int, str]) -> None: ...
@@ -209,15 +209,15 @@ def test_datetimes(self):
209209

210210
# Datetime
211211
def foo(x: datetime.datetime) -> None: ...
212-
assert to_sql(foo) == '`foo`(`x` DATETIME NOT NULL) RETURNS NULL'
212+
assert to_sql(foo) == '`foo`(`x` DATETIME NOT NULL) RETURNS TINYINT NULL'
213213

214214
# Date
215215
def foo(x: datetime.date) -> None: ...
216-
assert to_sql(foo) == '`foo`(`x` DATE NOT NULL) RETURNS NULL'
216+
assert to_sql(foo) == '`foo`(`x` DATE NOT NULL) RETURNS TINYINT NULL'
217217

218218
# Time
219219
def foo(x: datetime.timedelta) -> None: ...
220-
assert to_sql(foo) == '`foo`(`x` TIME NOT NULL) RETURNS NULL'
220+
assert to_sql(foo) == '`foo`(`x` TIME NOT NULL) RETURNS TINYINT NULL'
221221

222222
# Datetime + Date
223223
def foo(x: Union[datetime.datetime, datetime.date]) -> None: ...
@@ -229,75 +229,76 @@ def test_numerics(self):
229229
# Ints
230230
#
231231
def foo(x: int) -> None: ...
232-
assert to_sql(foo) == '`foo`(`x` BIGINT NOT NULL) RETURNS NULL'
232+
assert to_sql(foo) == '`foo`(`x` BIGINT NOT NULL) RETURNS TINYINT NULL'
233233

234234
def foo(x: np.int8) -> None: ...
235-
assert to_sql(foo) == '`foo`(`x` TINYINT NOT NULL) RETURNS NULL'
235+
assert to_sql(foo) == '`foo`(`x` TINYINT NOT NULL) RETURNS TINYINT NULL'
236236

237237
def foo(x: np.int16) -> None: ...
238-
assert to_sql(foo) == '`foo`(`x` SMALLINT NOT NULL) RETURNS NULL'
238+
assert to_sql(foo) == '`foo`(`x` SMALLINT NOT NULL) RETURNS TINYINT NULL'
239239

240240
def foo(x: np.int32) -> None: ...
241-
assert to_sql(foo) == '`foo`(`x` INT NOT NULL) RETURNS NULL'
241+
assert to_sql(foo) == '`foo`(`x` INT NOT NULL) RETURNS TINYINT NULL'
242242

243243
def foo(x: np.int64) -> None: ...
244-
assert to_sql(foo) == '`foo`(`x` BIGINT NOT NULL) RETURNS NULL'
244+
assert to_sql(foo) == '`foo`(`x` BIGINT NOT NULL) RETURNS TINYINT NULL'
245245

246246
#
247247
# Unsigned ints
248248
#
249249
def foo(x: np.uint8) -> None: ...
250-
assert to_sql(foo) == '`foo`(`x` TINYINT UNSIGNED NOT NULL) RETURNS NULL'
250+
assert to_sql(foo) == '`foo`(`x` TINYINT UNSIGNED NOT NULL) RETURNS TINYINT NULL'
251251

252252
def foo(x: np.uint16) -> None: ...
253-
assert to_sql(foo) == '`foo`(`x` SMALLINT UNSIGNED NOT NULL) RETURNS NULL'
253+
assert to_sql(foo) == '`foo`(`x` SMALLINT UNSIGNED NOT NULL) RETURNS TINYINT NULL'
254254

255255
def foo(x: np.uint32) -> None: ...
256-
assert to_sql(foo) == '`foo`(`x` INT UNSIGNED NOT NULL) RETURNS NULL'
256+
assert to_sql(foo) == '`foo`(`x` INT UNSIGNED NOT NULL) RETURNS TINYINT NULL'
257257

258258
def foo(x: np.uint64) -> None: ...
259-
assert to_sql(foo) == '`foo`(`x` BIGINT UNSIGNED NOT NULL) RETURNS NULL'
259+
assert to_sql(foo) == '`foo`(`x` BIGINT UNSIGNED NOT NULL) RETURNS TINYINT NULL'
260260

261261
#
262262
# Floats
263263
#
264264
def foo(x: float) -> None: ...
265-
assert to_sql(foo) == '`foo`(`x` DOUBLE NOT NULL) RETURNS NULL'
265+
assert to_sql(foo) == '`foo`(`x` DOUBLE NOT NULL) RETURNS TINYINT NULL'
266266

267267
def foo(x: np.float32) -> None: ...
268-
assert to_sql(foo) == '`foo`(`x` FLOAT NOT NULL) RETURNS NULL'
268+
assert to_sql(foo) == '`foo`(`x` FLOAT NOT NULL) RETURNS TINYINT NULL'
269269

270270
def foo(x: np.float64) -> None: ...
271-
assert to_sql(foo) == '`foo`(`x` DOUBLE NOT NULL) RETURNS NULL'
271+
assert to_sql(foo) == '`foo`(`x` DOUBLE NOT NULL) RETURNS TINYINT NULL'
272272

273273
#
274274
# Type collapsing
275275
#
276276
def foo(x: Union[np.int8, np.int16]) -> None: ...
277-
assert to_sql(foo) == '`foo`(`x` SMALLINT NOT NULL) RETURNS NULL'
277+
assert to_sql(foo) == '`foo`(`x` SMALLINT NOT NULL) RETURNS TINYINT NULL'
278278

279279
def foo(x: Union[np.int64, np.double]) -> None: ...
280-
assert to_sql(foo) == '`foo`(`x` DOUBLE NOT NULL) RETURNS NULL'
280+
assert to_sql(foo) == '`foo`(`x` DOUBLE NOT NULL) RETURNS TINYINT NULL'
281281

282282
def foo(x: Union[int, float]) -> None: ...
283-
assert to_sql(foo) == '`foo`(`x` DOUBLE NOT NULL) RETURNS NULL'
283+
assert to_sql(foo) == '`foo`(`x` DOUBLE NOT NULL) RETURNS TINYINT NULL'
284284

285285
def test_positional_and_keyword_parameters(self):
286286
# Keyword only
287287
def foo(x: int = 100) -> None: ...
288-
assert to_sql(foo) == '`foo`(`x` BIGINT NOT NULL DEFAULT 100) RETURNS NULL'
288+
assert to_sql(foo) == \
289+
'`foo`(`x` BIGINT NOT NULL DEFAULT 100) RETURNS TINYINT NULL'
289290

290291
# Multiple keywords
291292
def foo(x: int = 100, y: float = 3.14) -> None: ...
292293
assert to_sql(foo) == '`foo`(`x` BIGINT NOT NULL DEFAULT 100, ' \
293-
'`y` DOUBLE NOT NULL DEFAULT 3.14e0) RETURNS NULL'
294+
'`y` DOUBLE NOT NULL DEFAULT 3.14e0) RETURNS TINYINT NULL'
294295

295296
# Keywords and positional
296297
def foo(a: str, b: str, x: int = 100, y: float = 3.14) -> None: ...
297298
assert to_sql(foo) == '`foo`(`a` TEXT NOT NULL, ' \
298299
'`b` TEXT NOT NULL, ' \
299300
'`x` BIGINT NOT NULL DEFAULT 100, ' \
300-
'`y` DOUBLE NOT NULL DEFAULT 3.14e0) RETURNS NULL'
301+
'`y` DOUBLE NOT NULL DEFAULT 3.14e0) RETURNS TINYINT NULL'
301302

302303
# Variable positional
303304
def foo(*args: int) -> None: ...

0 commit comments

Comments
 (0)