Skip to content

Commit b32ac13

Browse files
committed
Add vector type
1 parent 36f2b13 commit b32ac13

File tree

3 files changed

+65
-1
lines changed

3 files changed

+65
-1
lines changed

singlestoredb/functions/dtypes.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1744,3 +1744,51 @@ def ARRAY(
17441744
out = SQLString(f'ARRAY({dtype})' + _modifiers(nullable=nullable))
17451745
out.name = name
17461746
return out
1747+
1748+
1749+
F32 = 'F32'
1750+
F64 = 'F64'
1751+
I8 = 'I8'
1752+
I16 = 'I16'
1753+
I32 = 'I32'
1754+
I64 = 'I64'
1755+
1756+
1757+
def VECTOR(
1758+
length: int,
1759+
element_type: str = F32,
1760+
*,
1761+
nullable: bool = True,
1762+
default: Optional[bytes] = None,
1763+
name: Optional[str] = None,
1764+
) -> SQLString:
1765+
"""
1766+
VECTOR type specification.
1767+
1768+
Parameters
1769+
----------
1770+
n : int
1771+
Number of elements in vector
1772+
element_type : str, optional
1773+
Type of the elements in the vector:
1774+
F32, F64, I8, I16, I32, I64
1775+
nullable : bool, optional
1776+
Can the value be NULL?
1777+
default : str, optional
1778+
Default value
1779+
name : str, optional
1780+
Name of the column / parameter
1781+
1782+
Returns
1783+
-------
1784+
SQLString
1785+
1786+
"""
1787+
out = f'VECTOR({int(length)}, {element_type})'
1788+
out = SQLString(
1789+
out + _modifiers(
1790+
nullable=nullable, default=default,
1791+
),
1792+
)
1793+
out.name = name
1794+
return out

singlestoredb/functions/signature.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -788,7 +788,7 @@ def get_schema(
788788
elif utils.is_dataframe(spec) or utils.is_vector(spec):
789789
if not overrides:
790790
raise TypeError(
791-
'type overrides must be specified for DataFrames / Tables',
791+
'type overrides must be specified for vectors or DataFrames / Tables',
792792
)
793793

794794
# Unsuported types

singlestoredb/tests/test_udf.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -694,3 +694,19 @@ def test_dtypes(self):
694694

695695
assert dt.ARRAY(dt.INT) == 'ARRAY(INT NULL) NULL'
696696
assert dt.ARRAY(dt.INT, nullable=False) == 'ARRAY(INT NULL) NOT NULL'
697+
698+
assert dt.VECTOR(8) == 'VECTOR(8, F32) NULL'
699+
assert dt.VECTOR(8, dt.F32) == 'VECTOR(8, F32) NULL'
700+
assert dt.VECTOR(8, dt.F64) == 'VECTOR(8, F64) NULL'
701+
assert dt.VECTOR(8, dt.I8) == 'VECTOR(8, I8) NULL'
702+
assert dt.VECTOR(8, dt.I16) == 'VECTOR(8, I16) NULL'
703+
assert dt.VECTOR(8, dt.I32) == 'VECTOR(8, I32) NULL'
704+
assert dt.VECTOR(8, dt.I64) == 'VECTOR(8, I64) NULL'
705+
706+
assert dt.VECTOR(8, nullable=False) == 'VECTOR(8, F32) NOT NULL'
707+
assert dt.VECTOR(8, dt.F32, nullable=False) == 'VECTOR(8, F32) NOT NULL'
708+
assert dt.VECTOR(8, dt.F64, nullable=False) == 'VECTOR(8, F64) NOT NULL'
709+
assert dt.VECTOR(8, dt.I8, nullable=False) == 'VECTOR(8, I8) NOT NULL'
710+
assert dt.VECTOR(8, dt.I16, nullable=False) == 'VECTOR(8, I16) NOT NULL'
711+
assert dt.VECTOR(8, dt.I32, nullable=False) == 'VECTOR(8, I32) NOT NULL'
712+
assert dt.VECTOR(8, dt.I64, nullable=False) == 'VECTOR(8, I64) NOT NULL'

0 commit comments

Comments
 (0)