-
Notifications
You must be signed in to change notification settings - Fork 99
Expand file tree
/
Copy pathschema.py
More file actions
302 lines (251 loc) · 10.4 KB
/
schema.py
File metadata and controls
302 lines (251 loc) · 10.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
from __future__ import annotations

from typing import TYPE_CHECKING, Any, cast

import sqlalchemy as sa
from marshmallow.fields import Field
from marshmallow.schema import Schema, SchemaMeta, SchemaOpts

from .convert import ModelConverter
from .exceptions import IncorrectSchemaTypeError
from .load_instance_mixin import LoadInstanceMixin, _ModelType

if TYPE_CHECKING:
    from sqlalchemy.ext.declarative import DeclarativeMeta
# This isn't really a field; it's a placeholder for the metaclass.
# This should be considered private API.
class SQLAlchemyAutoField(Field):
    """Placeholder produced by `auto_field`.

    `SQLAlchemySchemaMeta.get_auto_fields` swaps each (non-excluded) instance
    for a real marshmallow field built by `create_field`. An instance that
    reaches `_bind_to_schema` was never swapped out.
    """

    def __init__(
        self,
        *,
        column_name: str | None = None,
        model: type[DeclarativeMeta] | None = None,
        table: sa.Table | None = None,
        field_kwargs: dict[str, Any],
    ):
        """
        :param column_name: Column to generate the field from. When ``None``,
            the name the field is assigned to on the schema is used.
        :param model: Model to generate the field from, overriding ``Meta.model``.
        :param table: Table to generate the field from, overriding ``Meta.table``.
        :param field_kwargs: Keyword arguments forwarded to the generated field.
        :raises ValueError: If both ``model`` and ``table`` are given.
        """
        super().__init__()
        # Explicit ``is not None`` checks: truth-testing a SQLAlchemy ``Table``
        # may raise ``TypeError`` instead of returning a bool.
        if model is not None and table is not None:
            raise ValueError("Cannot pass both `model` and `table` options.")
        self.column_name = column_name
        self.model = model
        self.table = table
        self.field_kwargs = field_kwargs

    def create_field(
        self,
        schema_opts: SQLAlchemySchemaOpts,
        column_name: str,
        converter: ModelConverter,
    ):
        """Build the concrete marshmallow field for ``column_name``.

        A ``model`` or ``table`` passed to this field takes precedence. Only
        when neither was given does it fall back to the schema's
        ``Meta.model`` / ``Meta.table``.

        :param schema_opts: Options of the schema the field is declared on.
        :param column_name: Name of the column (or model attribute) to convert.
        :param converter: Converter used to create the field.
        :return: The field produced by ``converter.field_for`` (model source)
            or ``converter.column2field`` (table source).
        """
        model, table = self.model, self.table
        if model is None and table is None:
            # No per-field source was given: use the schema-level one.
            # (SQLAlchemySchemaOpts forbids setting both model and table.)
            model, table = schema_opts.model, schema_opts.table
        if model is not None:
            return converter.field_for(model, column_name, **self.field_kwargs)
        column = getattr(cast("sa.Table", table).columns, column_name)
        return converter.column2field(column, **self.field_kwargs)

    # This field should never be bound to a schema.
    # If this method is called, it's probably because the schema is not a SQLAlchemySchema.
    def _bind_to_schema(self, field_name: str, parent: Schema | Field) -> None:
        raise IncorrectSchemaTypeError(
            f"Cannot bind SQLAlchemyAutoField. Make sure that {parent} is a SQLAlchemySchema or SQLAlchemyAutoSchema."
        )
class SQLAlchemySchemaOpts(LoadInstanceMixin.Opts, SchemaOpts):
    """``class Meta`` options understood by `SQLAlchemySchema`.

    On top of the regular marshmallow options, these are read:

    - ``model``: SQLAlchemy model to build the `Schema` from. Cannot be
      combined with ``table``.
    - ``table``: SQLAlchemy table to build the `Schema` from. Cannot be
      combined with ``model``.
    - ``load_instance``: If `True`, deserialization produces model instances.
    - ``sqla_session``: Session used during deserialization. Only needed when
      ``load_instance`` is `True`; a session can instead be passed to the
      Schema's `load` method.
    - ``transient``: If `True`, instances are loaded in a transient state
      (the session is effectively ignored). Only relevant when
      ``load_instance`` is `True`.
    - ``model_converter``: `ModelConverter` class that turns the SQLAlchemy
      model into marshmallow fields.
    """

    table: sa.Table | None
    model_converter: type[ModelConverter]

    def __init__(self, meta, *args, **kwargs):
        # ``self.model`` (and the other load-instance options) are set by
        # the parent options classes.
        super().__init__(meta, *args, **kwargs)
        configured_table = getattr(meta, "table", None)
        self.table = configured_table
        both_given = self.model is not None and configured_table is not None
        if both_given:
            raise ValueError("Cannot set both `model` and `table` options.")
        self.model_converter = getattr(meta, "model_converter", ModelConverter)
class SQLAlchemyAutoSchemaOpts(SQLAlchemySchemaOpts):
    """``class Meta`` options understood by `SQLAlchemyAutoSchema`.

    Accepts everything `SQLAlchemySchemaOpts` does, plus:

    - ``include_fk``: Whether to generate fields for foreign-key columns;
      defaults to `False`.
    - ``include_relationships``: Whether to generate fields for relationships;
      defaults to `False`. Cannot be enabled together with ``table``.
    """

    include_fk: bool
    include_relationships: bool

    def __init__(self, meta, *args, **kwargs):
        super().__init__(meta, *args, **kwargs)
        self.include_fk = getattr(meta, "include_fk", False)
        wants_relationships = getattr(meta, "include_relationships", False)
        self.include_relationships = wants_relationships
        # A bare table carries no relationship information.
        if self.table is not None and wants_relationships:
            raise ValueError("Cannot set `table` and `include_relationships = True`.")
class SQLAlchemySchemaMeta(SchemaMeta):
    """Metaclass for `SQLAlchemySchema`.

    Extends marshmallow's field collection with three steps:

    * inherited fields that come from foreign-key columns are dropped when the
      options ask for it (`_maybe_filter_foreign_keys`);
    * `get_declared_sqla_fields` may contribute generated fields (a no-op here,
      overridden by subclasses);
    * every `SQLAlchemyAutoField` placeholder is resolved into a real field.
    """

    @classmethod
    def get_declared_fields(
        mcs,
        klass,
        cls_fields: list[tuple[str, Field]],
        inherited_fields: list[tuple[str, Field]],
        dict_cls: type[dict] = dict,
    ) -> dict[str, Field]:
        """Return the full field mapping for ``klass``.

        :param klass: Schema class being created; its ``opts`` are read.
        :param cls_fields: Fields declared directly on ``klass``.
        :param inherited_fields: Fields inherited from base classes.
        :param dict_cls: Mapping type to build results with.
        """
        opts = klass.opts
        converter = opts.model_converter(schema_cls=klass)
        # Only inherited fields are filtered here. Fields declared on ``klass``
        # itself are passed through untouched. Filtering happens only when
        # include_fk is False in the options.
        kept_inherited = mcs._maybe_filter_foreign_keys(
            inherited_fields, opts=opts, klass=klass
        )
        collected = super().get_declared_fields(
            klass, cls_fields, kept_inherited, dict_cls
        )
        collected.update(
            mcs.get_declared_sqla_fields(collected, converter, opts, dict_cls)
        )
        # Runs against the mapping already updated with the generated fields.
        collected.update(mcs.get_auto_fields(collected, converter, opts, dict_cls))
        return collected

    @classmethod
    def get_declared_sqla_fields(
        mcs,
        base_fields: dict[str, Field],
        converter: ModelConverter,
        opts: Any,
        dict_cls: type[dict],
    ) -> dict[str, Field]:
        """Hook for fields generated from ``opts.model`` / ``opts.table``.

        The base implementation generates nothing.
        """
        return {}

    @classmethod
    def get_auto_fields(
        mcs,
        fields: dict[str, Field],
        converter: ModelConverter,
        opts: Any,
        dict_cls: type[dict],
    ) -> dict[str, Field]:
        """Resolve each non-excluded `SQLAlchemyAutoField` into a real field.

        The placeholder's ``column_name`` is used when set; otherwise the
        attribute name it is assigned to is used as the column name.
        """
        resolved = {}
        for name, candidate in fields.items():
            if not isinstance(candidate, SQLAlchemyAutoField):
                continue
            if name in opts.exclude:
                continue
            resolved[name] = candidate.create_field(
                opts, candidate.column_name or name, converter
            )
        return dict_cls(resolved)

    @staticmethod
    def _maybe_filter_foreign_keys(
        fields: list[tuple[str, Field]],
        *,
        opts: SQLAlchemySchemaOpts,
        klass: SchemaMeta,
    ) -> list[tuple[str, Field]]:
        """Drop inherited fields generated from foreign-key columns.

        Fields are returned unchanged when there is no model/table, when the
        options have no ``include_fk`` attribute (plain `SQLAlchemySchemaOpts`),
        or when ``include_fk`` is ``True``. Foreign-key fields that some
        non-auto base declared explicitly are always kept.
        """
        if opts.model is None and opts.table is None:
            return fields
        if not hasattr(opts, "include_fk") or opts.include_fk is True:
            return fields

        source = opts.model or opts.table
        fk_names = {
            col.key
            for col in sa.inspect(source).columns  # type: ignore[union-attr]
            if col.foreign_keys
        }

        # Collect fields explicitly declared in non-AutoSchema bases.
        # XXX: Avoid issubclass(base, Schema) because it causes quadratic
        # ABCMeta cache growth with many schema classes (#665).
        explicit_names: set[str] = set()
        for base in klass.__mro__:
            if base is object:
                break
            base_opts_cls = getattr(base, "OPTIONS_CLASS", None)
            if base_opts_cls is not None and issubclass(
                base_opts_cls, SQLAlchemyAutoSchemaOpts
            ):
                continue
            explicit_names.update(base.__dict__.get("_declared_fields") or ())

        # A field is dropped only if it is an FK column AND not declared explicitly.
        dropped = fk_names - explicit_names
        return [(name, field) for name, field in fields if name not in dropped]
class SQLAlchemyAutoSchemaMeta(SQLAlchemySchemaMeta):
    """Metaclass for `SQLAlchemyAutoSchema`.

    Generates fields from ``Meta.table`` (preferred when set) or
    ``Meta.model`` via the schema's converter.
    """

    @classmethod
    def get_declared_sqla_fields(
        mcs, base_fields, converter: ModelConverter, opts, dict_cls
    ):
        """Return fields generated by the converter from the table or model.

        ``fields``, ``exclude``, ``include_fk``, the already-declared
        ``base_fields`` and ``dict_cls`` are forwarded to the converter. For a
        model, ``include_relationships`` is forwarded as well.
        """
        generated = dict_cls()
        if opts.table is None and opts.model is None:
            return generated

        shared = {
            "fields": opts.fields,
            "exclude": opts.exclude,
            "include_fk": opts.include_fk,
            "base_fields": base_fields,
            "dict_cls": dict_cls,
        }
        if opts.table is not None:
            generated.update(converter.fields_for_table(opts.table, **shared))
        else:
            generated.update(
                converter.fields_for_model(
                    opts.model,
                    include_relationships=opts.include_relationships,
                    **shared,
                )
            )
        return generated
class SQLAlchemySchema(
    LoadInstanceMixin.Schema[_ModelType], Schema, metaclass=SQLAlchemySchemaMeta
):
    """Schema tied to a SQLAlchemy model or table through ``class Meta``.

    No fields are generated implicitly. Declare each one with `auto_field`
    to derive it from the corresponding column.

    Example: ::

        from marshmallow_sqlalchemy import SQLAlchemySchema, auto_field
        from mymodels import User


        class UserSchema(SQLAlchemySchema):
            class Meta:
                model = User

            id = auto_field()
            created_at = auto_field(dump_only=True)
            name = auto_field()
    """

    OPTIONS_CLASS = SQLAlchemySchemaOpts
class SQLAlchemyAutoSchema(
    SQLAlchemySchema[_ModelType], metaclass=SQLAlchemyAutoSchemaMeta
):
    """Schema whose fields are generated from the columns of a SQLAlchemy
    model or table.

    Individual fields can still be customized with `auto_field`.

    Example: ::

        from marshmallow_sqlalchemy import SQLAlchemyAutoSchema, auto_field
        from mymodels import User


        class UserSchema(SQLAlchemyAutoSchema):
            class Meta:
                model = User
                # OR
                # table = User.__table__

            created_at = auto_field(dump_only=True)
    """

    OPTIONS_CLASS = SQLAlchemyAutoSchemaOpts
def auto_field(
    column_name: str | None = None,
    *,
    model: type[DeclarativeMeta] | None = None,
    table: sa.Table | None = None,
    # TODO: add type annotations for **kwargs
    **kwargs,
) -> SQLAlchemyAutoField:
    """Declare a schema field whose definition is generated from a column.

    :param column_name: Column to build the field from. ``None`` means the
        column shares the field's name. When given and no ``attribute``
        override is passed, ``attribute`` defaults to ``column_name``.
    :param model: Model to build the field from; falls back to
        ``class Meta``'s ``model`` when ``None``.
    :param table: Table to build the field from; falls back to
        ``class Meta``'s ``table`` when ``None``.
    :param kwargs: Overrides forwarded to the generated field.
    :return: A placeholder that the schema metaclass replaces with the real field.
    """
    overrides = kwargs
    # Serialize from the real column attribute unless the caller chose one.
    if column_name is not None and "attribute" not in overrides:
        overrides["attribute"] = column_name
    return SQLAlchemyAutoField(
        column_name=column_name,
        model=model,
        table=table,
        field_kwargs=overrides,
    )