Skip to content

Commit 459f8f9

Browse files
committed
Experimental support for Pydantic models
1 parent 2882aae commit 459f8f9

File tree

5 files changed

+219
-3
lines changed

5 files changed

+219
-3
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,4 +43,4 @@ repos:
4343
rev: v1.6.1
4444
hooks:
4545
- id: mypy
46-
additional_dependencies: [types-requests]
46+
additional_dependencies: [types-requests, types-python-dateutil]

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
build
2+
python-dateutil
23
parsimonious
34
PyJWT
45
requests

singlestoredb/connection.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ class DataFrame(object): # type: ignore
3131
def itertuples(self, *args: Any, **kwargs: Any) -> None:
3232
pass
3333

34+
from . import dbobjects
3435
from . import auth
3536
from . import exceptions
3637
from .config import get_option
@@ -1288,6 +1289,8 @@ def show(self) -> ShowAccessor:
12881289
"""Access server properties managed by the SHOW statement."""
12891290
return ShowAccessor(self)
12901291

1292+
dbs = property(dbobjects.dbs)
1293+
12911294

12921295
#
12931296
# NOTE: When adding parameters to this function, you should always

singlestoredb/dbobjects.py

Lines changed: 203 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,203 @@
1+
import datetime
2+
import math
3+
from typing import Any
4+
from typing import Callable
5+
from typing import Dict
6+
from typing import List
7+
from typing import Optional
8+
from typing import Tuple
9+
from typing import Type
10+
11+
from dateutil import parser
12+
from pydantic import BaseModel
13+
from pydantic import create_model
14+
from pydantic import Field
15+
16+
from .connection import Connection
17+
18+
19+
# Maps a SQL type name (looked up with the DATA_TYPE value of an
# information_schema.COLUMNS row) to a 3-tuple of:
#
#   1. the Python type used as the Pydantic field annotation,
#   2. keyword arguments forwarded to ``pydantic.Field`` (range / length limits),
#   3. a converter applied to the column's COLUMN_DEFAULT value.
#
# Each type name may appear only once: ``dict()`` rejects repeated keyword
# arguments with a SyntaxError.
DB_TYPEMAP: Dict[str, Tuple[Type[Any], Dict[str, Any], Callable[[Any], Any]]] = dict(
    char=(str, {}, str),
    varchar=(str, dict(max_length=21844), str),
    tinyint=(int, dict(ge=-128, le=127), int),
    smallint=(int, dict(ge=-32768, le=32767), int),
    mediumint=(int, dict(ge=-8388608, le=8388607), int),
    int=(int, dict(ge=-2147483648, le=2147483647), int),
    bigint=(int, dict(ge=-2**63, le=(2**63)-1), int),
    binary=(bytes, {}, bytes),
    varbinary=(bytes, dict(max_length=65533), bytes),
    # Same limit as tinytext below (one length byte).
    tinyblob=(bytes, dict(max_length=255), bytes),
    mediumblob=(bytes, dict(max_length=16777216), bytes),
    blob=(bytes, dict(max_length=65535), bytes),
    longblob=(bytes, dict(max_length=4194304000), bytes),
    tinytext=(str, dict(max_length=255), str),
    mediumtext=(str, dict(max_length=16777216), str),
    text=(str, dict(max_length=65535), str),
    longtext=(str, dict(max_length=4194304000), str),
    float=(float, {}, float),
    double=(float, {}, float),
    datetime=(datetime.datetime, {}, parser.parse),  # noqa
    # NOTE(review): the field type is timedelta but the default-value
    # converter produces a datetime.time -- confirm which one is intended.
    time=(datetime.timedelta, {}, lambda x: parser.parse(x).time()),  # noqa
    date=(datetime.date, {}, lambda x: parser.parse(x).date()),  # noqa
    timestamp=(datetime.datetime, {}, parser.parse),  # noqa
    year=(int, dict(ge=1901, le=2155), int),
)
46+
47+
48+
# Common base of the table-level models built by ``Table.table_model``.
# ``Table.insert`` uses it in an ``isinstance`` check to recognize a model
# that carries a list of rows.
class TableBaseModel(BaseModel):
    pass
50+
51+
52+
# Common base of the per-row models built by ``Table.row_model``.
# ``Table.insert`` uses it in an ``isinstance`` check to recognize a model
# that describes a single row.
class TableRowBaseModel(BaseModel):
    pass
54+
55+
56+
class Table:
    """
    Handle for one database table.

    Builds Pydantic models from the table's ``information_schema`` metadata
    and inserts the data held by instances of those models.

    Parameters
    ----------
    conn : Connection
        Connection used for every query issued by this object
    database : str
        Name of the database that contains the table
    name : str
        Name of the table

    """

    def __init__(self, conn: Connection, database: str, name: str) -> None:
        self.connection = conn
        self.database = database
        self.name = name
        # Column metadata is queried once, at construction time.
        self.schema = self._schema()

    def insert(self, model: TableBaseModel) -> int:
        """
        Insert the rows held by a model into the table.

        Parameters
        ----------
        model : TableBaseModel
            A table model (its ``rows`` are inserted). A single
            ``TableRowBaseModel`` instance is also accepted at runtime
            and is inserted as one row.

        Raises
        ------
        TypeError
            If ``model`` is neither a table model nor a row model

        Returns
        -------
        int
            The value returned by the cursor's ``executemany``, or 0
            when there are no rows to insert

        """
        if isinstance(model, TableBaseModel):
            params = model.model_dump()['rows']
        elif isinstance(model, TableRowBaseModel):
            # ``executemany`` takes a sequence of parameter sets, so a
            # single row has to be wrapped in a list.
            params = [model.model_dump()]
        else:
            raise TypeError('Unrecognized parameter type for insert')

        if not params:
            return 0

        columns = [x['COLUMN_NAME'] for x in self.schema]
        subs = ', '.join(f'%({col})s' for col in columns)
        names = ', '.join(f'`{col}`' for col in columns)
        query = f'INSERT INTO `{self.database}`.`{self.name}`({names}) VALUES ({subs})'

        with self.connection.cursor() as cur:
            return cur.executemany(query, params)

    def insert_completions(self, client: Any, **kwargs: Any) -> int:
        """
        Call ``client.create`` for rows of this table and insert the result.

        The table model is passed as ``response_model``. The table comment
        is added to the conversation as an assistant message, a default
        system message is prepended when none is given, and the table
        comment is also used as the user message when none is given.

        Parameters
        ----------
        client : Any
            Object with a ``create(**kwargs)`` method whose return value
            is accepted by :meth:`insert`
        **kwargs : Any
            Keyword arguments forwarded to ``client.create``; ``messages``
            is copied before it is modified

        Returns
        -------
        int
            Result of :meth:`insert`

        """
        kwargs['response_model'] = self.table_model()

        # Work on a copy so the caller's list is not mutated.
        messages = list(kwargs.get('messages', []))
        kwargs['messages'] = messages

        roles = {msg['role'] for msg in messages}

        # Fetched once; it is needed for up to two messages below.
        comment = self._table_info()['TABLE_COMMENT']

        if 'system' not in roles:
            messages.insert(
                0, dict(role='system', content='You are a helpful assistant'),
            )

        messages.insert(1, dict(role='assistant', content=comment))

        if 'user' not in roles:
            messages.append(dict(role='user', content=comment))

        return self.insert(client.create(**kwargs))

    def _schema(self) -> List[Dict[str, Any]]:
        """Return one ``information_schema.COLUMNS`` dict per column, in ordinal order."""
        query = '''
            SELECT * FROM information_schema.COLUMNS
            WHERE TABLE_SCHEMA = %s AND TABLE_NAME = %s ORDER BY ORDINAL_POSITION
        '''
        with self.connection.cursor() as cur:
            cur.execute(query, (self.database, self.name))
            names = [x.name for x in cur.description or []]
            return [dict(zip(names, row)) for row in cur]

    def _table_info(self) -> Dict[str, Any]:
        """Return the table's ``information_schema.TABLES`` row, or ``{}`` if there is none."""
        query = '''
            SELECT * FROM information_schema.TABLES
            WHERE TABLE_SCHEMA = %s AND TABLE_NAME = %s
        '''
        with self.connection.cursor() as cur:
            cur.execute(query, (self.database, self.name))
            names = [x.name for x in cur.description or []]
            # Only the first row is used.
            for row in cur:
                return dict(zip(names, row))
        return {}

    def table_model(self) -> Type[TableBaseModel]:
        """
        Build a Pydantic model with a single ``rows`` field.

        ``rows`` is a list of :meth:`row_model` instances; its description
        is the table comment, or a generic text when the comment is empty.

        """
        tbl_info = self._table_info()
        desc = tbl_info.get('TABLE_COMMENT') or 'rows of data in the table'
        row_model = self.row_model()
        tbl_model = create_model(
            'TableModel',
            rows=(List[row_model], Field(description=desc, default=[])),  # type: ignore
            __base__=(TableBaseModel,),
        )
        return tbl_model

    def row_model(self) -> Type[TableRowBaseModel]:
        """Build a Pydantic model with one field per column of the table."""
        row_model = create_model(
            'RowModel',
            **dict(self._get_model_field(x) for x in self.schema),
            __base__=(TableRowBaseModel,),
        )
        return row_model

    def _get_model_field(
        self,
        info: Dict[str, Any],
    ) -> Tuple[str, Tuple[Any, Any]]:
        """
        Translate one column's metadata into a ``create_model`` field.

        Parameters
        ----------
        info : Dict[str, Any]
            One entry of ``self.schema``

        Raises
        ------
        KeyError
            If the column's DATA_TYPE has no entry in ``DB_TYPEMAP``

        Returns
        -------
        Tuple[str, Tuple[Any, Any]]
            ``(column name, (annotation, Field(...)))``

        """
        # Presumably IS_NULLABLE is 'YES' or 'NO'; only the latter contains 'N'.
        is_required = 'N' in info['IS_NULLABLE']

        data_type = info['DATA_TYPE']
        dtype, dtype_params, dtype_conv = DB_TYPEMAP[data_type]
        if not is_required:
            dtype = Optional[dtype]  # type: ignore

        # Copy so the shared DB_TYPEMAP entry is never modified.
        kwargs = dtype_params.copy()

        # The comment may be NULL rather than an empty string.
        comment = (info['COLUMN_COMMENT'] or '').strip()
        if comment:
            kwargs['description'] = comment

        if info['COLUMN_DEFAULT']:
            kwargs['default'] = dtype_conv(info['COLUMN_DEFAULT'])
        elif not is_required:
            kwargs['default'] = None

        # A declared column length overrides the generic limit of the type.
        max_length = info['CHARACTER_MAXIMUM_LENGTH']
        if max_length is not None and not math.isnan(max_length):
            kwargs['max_length'] = int(max_length)

        # Compare the SQL type name, not the Python type. This branch only
        # becomes reachable once DB_TYPEMAP has a 'decimal' entry.
        if data_type == 'decimal':
            max_digits = info['NUMERIC_PRECISION']
            if max_digits is not None and not math.isnan(max_digits):
                kwargs['max_digits'] = int(max_digits)

            decimal_places = info['NUMERIC_SCALE']
            if decimal_places is not None and not math.isnan(decimal_places):
                kwargs['decimal_places'] = int(decimal_places)

        return (str(info['COLUMN_NAME']), (dtype, Field(**kwargs)))
185+
186+
187+
class Database:
    """
    Handle for one database on a connection.

    Parameters
    ----------
    conn : Connection
        Connection used to query the database
    name : str
        Name of the database

    """

    def __init__(self, conn: Connection, name: str) -> None:
        self.connection = conn
        self.name = name

    @property
    def tables(self) -> Dict[str, Table]:
        """
        Tables of the database, keyed by table name.

        The list is queried with ``SHOW TABLES`` on every access.

        """
        # Double any backtick so the name stays a single quoted identifier.
        quoted = self.name.replace('`', '``')

        with self.connection.cursor() as cur:
            cur.execute(f'SHOW TABLES IN `{quoted}`')
            table_names = [row[0] for row in cur]

        # Each Table runs its own metadata query on this connection when it
        # is constructed, so build them only after the SHOW TABLES result
        # has been read completely and its cursor is closed.
        return {
            name: Table(self.connection, self.name, name)
            for name in table_names
        }
198+
199+
200+
def dbs(self: Connection) -> Dict[str, Database]:
    """
    Return the databases reported by ``SHOW DATABASES``, keyed by name.

    Written with a ``self`` parameter so it can be attached to the
    connection class as a property.

    """
    databases: Dict[str, Database] = {}
    with self.cursor() as cur:
        cur.execute('SHOW DATABASES')
        for row in cur:
            db_name = row[0]
            databases[db_name] = Database(self, db_name)
        return databases

singlestoredb/functions/ext/asgi.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -139,12 +139,21 @@ def get_func_names(funcs: str) -> List[Tuple[str, str]]:
139139

140140
def as_tuple(x: Any) -> Any:
    """
    Convert a Pydantic model or dataclass instance to a tuple of its values.

    Parameters
    ----------
    x : Any
        Value to convert

    Returns
    -------
    Any
        A tuple of field values for objects exposing ``model_fields``
        (dumped with ``model_dump``) and for dataclass instances;
        any other value, including a class, is returned unchanged

    """
    # Model and dataclass *classes* also pass the checks below, but only
    # instances can be dumped, so classes are handed back untouched.
    if isinstance(x, type):
        return x
    if hasattr(x, 'model_fields'):
        return tuple(x.model_dump().values())
    if dataclasses.is_dataclass(x):
        return dataclasses.astuple(x)
    return x
146146

147147

148+
def as_list_of_tuples(x: Any) -> Any:
    """
    Convert a sequence of Pydantic models or dataclass instances to tuples.

    The kind of conversion is chosen from the first element only.

    Parameters
    ----------
    x : Any
        Value to convert

    Returns
    -------
    Any
        A list with one tuple of field values per element when ``x`` is a
        non-empty list or tuple whose first element exposes
        ``model_fields`` or is a dataclass instance; otherwise ``x``
        is returned unchanged

    """
    if isinstance(x, (list, tuple)) and len(x) > 0:
        first = x[0]
        # Classes pass the checks below as well, but only instances can be
        # dumped; a sequence of classes is returned unchanged.
        if not isinstance(first, type):
            if hasattr(first, 'model_fields'):
                return [tuple(y.model_dump().values()) for y in x]
            if dataclasses.is_dataclass(first):
                return [dataclasses.astuple(y) for y in x]
    return x
155+
156+
148157
def make_func(
149158
name: str,
150159
func: Callable[..., Any],
@@ -183,7 +192,7 @@ async def do_func(
183192
out_ids: List[int] = []
184193
out = []
185194
for i, res in zip(row_ids, func_map(func, rows)):
186-
out.extend(as_tuple(res))
195+
out.extend(as_list_of_tuples(res))
187196
out_ids.extend([row_ids[i]] * (len(out)-len(out_ids)))
188197
return out_ids, out
189198

0 commit comments

Comments
 (0)