Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 23 additions & 12 deletions django/db/backends/base/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,13 +208,6 @@ def distinct_sql(self, fields, params):
else:
return ["DISTINCT"], []

def fetch_returned_insert_columns(self, cursor, returning_params):
"""
Given a cursor object that has just performed an INSERT...RETURNING
statement into a table, return the newly created data.
"""
return cursor.fetchone()

def force_group_by(self):
"""
Return a GROUP BY clause to use with a HAVING clause when no grouping
Expand Down Expand Up @@ -358,13 +351,31 @@ def process_clob(self, value):
"""
return value

def return_insert_columns(self, fields):
def returning_columns(self, fields):
"""
For backends that support returning columns as part of an insert query,
return the SQL and params to append to the INSERT query. The returned
fragment should contain a format string to hold the appropriate column.
For backends that support returning columns as part of an insert or
update query, return the SQL and params to append to the query.
The returned fragment should contain a format string to hold the
appropriate column.
"""
pass
if not fields:
return "", ()
columns = [
"%s.%s"
% (
self.quote_name(field.model._meta.db_table),
self.quote_name(field.column),
)
for field in fields
]
return "RETURNING %s" % ", ".join(columns), ()

def fetch_returned_rows(self, cursor, returning_params):
"""
Given a cursor object for a DML query with a RETURNING statement,
return the selected returning rows of tuples.
"""
return cursor.fetchall()

def compiler(self, compiler_name):
"""
Expand Down
21 changes: 0 additions & 21 deletions django/db/backends/mysql/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,13 +148,6 @@ def time_trunc_sql(self, lookup_type, sql, params, tzname=None):
else:
return f"TIME({sql})", params

def fetch_returned_insert_rows(self, cursor):
"""
Given a cursor object that has just performed an INSERT...RETURNING
statement into a table, return the tuple of returned data.
"""
return cursor.fetchall()

def format_for_duration_arithmetic(self, sql):
return "INTERVAL %s MICROSECOND" % sql

Expand Down Expand Up @@ -182,20 +175,6 @@ def quote_name(self, name):
return name # Quoting once is enough.
return "`%s`" % name

def return_insert_columns(self, fields):
# MySQL doesn't support an INSERT...RETURNING statement.
if not fields:
return "", ()
columns = [
"%s.%s"
% (
self.quote_name(field.model._meta.db_table),
self.quote_name(field.column),
)
for field in fields
]
return "RETURNING %s" % ", ".join(columns), ()

def sql_flush(self, style, tables, *, reset_sequences=False, allow_cascade=False):
if not tables:
return []
Expand Down
48 changes: 22 additions & 26 deletions django/db/backends/oracle/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from django.utils.regex_helper import _lazy_re_compile

from .base import Database
from .utils import BulkInsertMapper, InsertVar, Oracle_datetime
from .utils import BoundVar, BulkInsertMapper, Oracle_datetime


class DatabaseOperations(BaseDatabaseOperations):
Expand Down Expand Up @@ -298,12 +298,27 @@ def convert_empty_bytes(value, expression, connection):
def deferrable_sql(self):
return " DEFERRABLE INITIALLY DEFERRED"

def fetch_returned_insert_columns(self, cursor, returning_params):
columns = []
for param in returning_params:
value = param.get_value()
columns.append(value[0])
return tuple(columns)
def returning_columns(self, fields):
if not fields:
return "", ()
field_names = []
params = []
for field in fields:
field_names.append(
"%s.%s"
% (
self.quote_name(field.model._meta.db_table),
self.quote_name(field.column),
)
)
params.append(BoundVar(field))
return "RETURNING %s INTO %s" % (
", ".join(field_names),
", ".join(["%s"] * len(params)),
), tuple(params)

def fetch_returned_rows(self, cursor, returning_params):
return list(zip(*(param.get_value() for param in returning_params)))

def no_limit_value(self):
return None
Expand Down Expand Up @@ -391,25 +406,6 @@ def regex_lookup(self, lookup_type):
match_option = "'i'"
return "REGEXP_LIKE(%%s, %%s, %s)" % match_option

def return_insert_columns(self, fields):
if not fields:
return "", ()
field_names = []
params = []
for field in fields:
field_names.append(
"%s.%s"
% (
self.quote_name(field.model._meta.db_table),
self.quote_name(field.column),
)
)
params.append(InsertVar(field))
return "RETURNING %s INTO %s" % (
", ".join(field_names),
", ".join(["%s"] * len(params)),
), tuple(params)

def __foreign_key_constraints(self, table_name, recursive):
with self.connection.cursor() as cursor:
if recursive:
Expand Down
2 changes: 1 addition & 1 deletion django/db/backends/oracle/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from .base import Database


class InsertVar:
class BoundVar:
"""
A late-binding cursor variable that can be passed to Cursor.execute
as a parameter, in order to receive the id of the row created by an
Expand Down
20 changes: 0 additions & 20 deletions django/db/backends/postgresql/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,13 +155,6 @@ def bulk_insert_sql(self, fields, placeholder_rows):
return f"SELECT * FROM {placeholder_rows}"
return super().bulk_insert_sql(fields, placeholder_rows)

def fetch_returned_insert_rows(self, cursor):
"""
Given a cursor object that has just performed an INSERT...RETURNING
statement into a table, return the tuple of returned data.
"""
return cursor.fetchall()

def lookup_cast(self, lookup_type, internal_type=None):
lookup = "%s"
# Cast text lookups to text to allow things like filter(x__contains=4)
Expand Down Expand Up @@ -324,19 +317,6 @@ def last_executed_query(self, cursor, sql, params):
return cursor.query.decode()
return None

def return_insert_columns(self, fields):
if not fields:
return "", ()
columns = [
"%s.%s"
% (
self.quote_name(field.model._meta.db_table),
self.quote_name(field.column),
)
for field in fields
]
return "RETURNING %s" % ", ".join(columns), ()

if is_psycopg3:

def adapt_integerfield_value(self, value, internal_type):
Expand Down
21 changes: 0 additions & 21 deletions django/db/backends/sqlite3/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,13 +84,6 @@ def date_extract_sql(self, lookup_type, sql, params):
"""
return f"django_date_extract(%s, {sql})", (lookup_type.lower(), *params)

def fetch_returned_insert_rows(self, cursor):
"""
Given a cursor object that has just performed an INSERT...RETURNING
statement into a table, return the list of returned data.
"""
return cursor.fetchall()

def format_for_duration_arithmetic(self, sql):
"""Do nothing since formatting is handled in the custom function."""
return sql
Expand Down Expand Up @@ -399,20 +392,6 @@ def insert_statement(self, on_conflict=None):
return "INSERT OR IGNORE INTO"
return super().insert_statement(on_conflict=on_conflict)

def return_insert_columns(self, fields):
# SQLite < 3.35 doesn't support an INSERT...RETURNING statement.
if not fields:
return "", ()
columns = [
"%s.%s"
% (
self.quote_name(field.model._meta.db_table),
self.quote_name(field.column),
)
for field in fields
]
return "RETURNING %s" % ", ".join(columns), ()

def on_conflict_suffix_sql(self, fields, on_conflict, update_fields, unique_fields):
if (
on_conflict == OnConflict.UPDATE
Expand Down
20 changes: 8 additions & 12 deletions django/db/models/sql/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1890,7 +1890,7 @@ def as_sql(self):
result.append(on_conflict_suffix_sql)
# Skip empty r_sql to allow subclasses to customize behavior for
# 3rd party backends. Refs #19096.
r_sql, self.returning_params = self.connection.ops.return_insert_columns(
r_sql, self.returning_params = self.connection.ops.returning_columns(
self.returning_fields
)
if r_sql:
Expand Down Expand Up @@ -1925,20 +1925,16 @@ def execute_sql(self, returning_fields=None):
cursor.execute(sql, params)
if not self.returning_fields:
return []
obj_len = len(self.query.objs)
if (
self.connection.features.can_return_rows_from_bulk_insert
and len(self.query.objs) > 1
and obj_len > 1
) or (
self.connection.features.can_return_columns_from_insert and obj_len == 1
):
rows = self.connection.ops.fetch_returned_insert_rows(cursor)
cols = [field.get_col(opts.db_table) for field in self.returning_fields]
elif self.connection.features.can_return_columns_from_insert:
assert len(self.query.objs) == 1
rows = [
self.connection.ops.fetch_returned_insert_columns(
cursor,
self.returning_params,
)
]
rows = self.connection.ops.fetch_returned_rows(
cursor, self.returning_params
)
cols = [field.get_col(opts.db_table) for field in self.returning_fields]
elif returning_fields and isinstance(
returning_field := returning_fields[0], AutoField
Expand Down
23 changes: 10 additions & 13 deletions django/middleware/csp.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from http import HTTPStatus

from django.conf import settings
from django.utils.csp import CSP, LazyNonce, build_policy
from django.utils.deprecation import MiddlewareMixin
Expand All @@ -14,22 +12,21 @@ def process_request(self, request):
request._csp_nonce = LazyNonce()

def process_response(self, request, response):
# In DEBUG mode, exclude CSP headers for specific status codes that
# trigger the debug view.
exempted_status_codes = {
HTTPStatus.NOT_FOUND,
HTTPStatus.INTERNAL_SERVER_ERROR,
}
if settings.DEBUG and response.status_code in exempted_status_codes:
return response

nonce = get_nonce(request)

sentinel = object()
if (csp_config := getattr(response, "_csp_config", sentinel)) is sentinel:
csp_config = settings.SECURE_CSP
if (csp_ro_config := getattr(response, "_csp_ro_config", sentinel)) is sentinel:
csp_ro_config = settings.SECURE_CSP_REPORT_ONLY

for header, config in [
(CSP.HEADER_ENFORCE, settings.SECURE_CSP),
(CSP.HEADER_REPORT_ONLY, settings.SECURE_CSP_REPORT_ONLY),
(CSP.HEADER_ENFORCE, csp_config),
(CSP.HEADER_REPORT_ONLY, csp_ro_config),
]:
# If headers are already set on the response, don't overwrite them.
# This allows for views to set their own CSP headers as needed.
# An empty config means CSP headers are not added to the response.
if config and header not in response:
response.headers[str(header)] = build_policy(config, nonce)

Expand Down
5 changes: 5 additions & 0 deletions django/views/debug.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from django.utils.module_loading import import_string
from django.utils.regex_helper import _lazy_re_compile
from django.utils.version import get_docs_version
from django.views.decorators.csp import csp_override, csp_report_only_override
from django.views.decorators.debug import coroutine_functions_to_sensitive_variables

# Minimal Django templates engine to render the error templates
Expand Down Expand Up @@ -59,6 +60,8 @@ def __repr__(self):
return repr(self._wrapped)


@csp_override({})
@csp_report_only_override({})
def technical_500_response(request, exc_type, exc_value, tb, status_code=500):
"""
Create a technical server error response. The last three arguments are
Expand Down Expand Up @@ -606,6 +609,8 @@ def get_exception_traceback_frames(self, exc_value, tb):
tb = tb.tb_next


@csp_override({})
@csp_report_only_override({})
def technical_404_response(request, exception):
"""Create a technical 404 error response. `exception` is the Http404."""
try:
Expand Down
39 changes: 39 additions & 0 deletions django/views/decorators/csp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from functools import wraps

from asgiref.sync import iscoroutinefunction


def _make_csp_decorator(config_attr_name, config_attr_value):
"""General CSP override decorator factory."""

if not isinstance(config_attr_value, dict):
raise TypeError("CSP config should be a mapping.")

def decorator(view_func):
@wraps(view_func)
async def _wrapped_async_view(request, *args, **kwargs):
response = await view_func(request, *args, **kwargs)
setattr(response, config_attr_name, config_attr_value)
return response

@wraps(view_func)
def _wrapped_sync_view(request, *args, **kwargs):
response = view_func(request, *args, **kwargs)
setattr(response, config_attr_name, config_attr_value)
return response

if iscoroutinefunction(view_func):
return _wrapped_async_view
return _wrapped_sync_view

return decorator


def csp_override(config):
"""Override the Content-Security-Policy header for a view."""
return _make_csp_decorator("_csp_config", config)


def csp_report_only_override(config):
"""Override the Content-Security-Policy-Report-Only header for a view."""
return _make_csp_decorator("_csp_ro_config", config)
Loading
Loading