1+ from __future__ import annotations
2+
13import inspect
24
35from copy import deepcopy
6+ from typing import TYPE_CHECKING , Any , Generic , TypeGuard , TypeVar , overload
7+
8+ if TYPE_CHECKING :
9+ from collections .abc import Callable , Generator
10+ from typing_extensions import Self
11+
12+ from webob .response import Response as BaseResponse
13+
14+ _T = TypeVar ("_T" )
415
516SELF = "'self'"
617UNSAFE_INLINE = "'unsafe-inline'"
920STRICT_DYNAMIC = "'strict-dynamic'"
1021
1122
12- class Directive :
23+ class Directive ( Generic [ _T ]) :
1324 """Descriptor for the management and rendering of CSP directives.
1425
1526 Uses types to do some basic sanity checking. This does not ensure
@@ -20,13 +31,19 @@ class Directive:
2031
2132 """
2233
23- def __init__ (self , name , type , default , render ):
34+ def __init__ (
35+ self ,
36+ name : str ,
37+ type : type [_T ],
38+ default : Callable [[], _T ],
39+ render : Callable [[_T ], str | None ],
40+ ) -> None :
2441 self .name = name
2542 self .type = type
2643 self .default = default
2744 self .renderer = render
2845
29- def render (self , instance ) :
46+ def render (self , instance : ContentSecurityPolicy ) -> str | None :
3047 if self .name not in instance .__dict__ :
3148 return None
3249
@@ -35,49 +52,55 @@ def render(self, instance):
3552
3653 return self .renderer (instance .__dict__ [self .name ])
3754
38- def __get__ (self , instance , cls ):
55+ @overload
56+ def __get__ (self , instance : None , cls : type [ContentSecurityPolicy ]) -> Self : ...
57+ @overload
58+ def __get__ (
59+ self , instance : ContentSecurityPolicy , cls : type [ContentSecurityPolicy ]
60+ ) -> _T : ...
61+
62+ def __get__ (
63+ self , instance : ContentSecurityPolicy | None , cls : type [ContentSecurityPolicy ]
64+ ) -> _T | Self :
3965 if instance is None :
4066 return self
4167
4268 if self .name not in instance .__dict__ :
4369 instance .__dict__ [self .name ] = self .default ()
4470
45- return instance .__dict__ [self .name ]
71+ return instance .__dict__ [self .name ] # type: ignore[no-any-return]
4672
47- def __set__ (self , instance , value ) :
73+ def __set__ (self , instance : ContentSecurityPolicy , value : _T ) -> None :
4874 if not isinstance (value , self .type ):
4975 raise TypeError (f"Expected type { self .type } " )
5076
5177 instance .__dict__ [self .name ] = value
5278
5379
54- class SetDirective (Directive ):
55- def __init__ (self , name ):
56- parent = super ()
57- parent .__init__ (name , type = set , default = set , render = render_set )
80+ class SetDirective (Directive [set [str ]]):
81+ def __init__ (self , name : str ) -> None :
82+ super ().__init__ (name , type = set , default = set , render = render_set )
5883
5984
60- class SingleValueDirective (Directive ):
61- def __init__ (self , name ):
62- parent = super ()
63- parent .__init__ (name , type = str , default = str , render = str )
85+ class SingleValueDirective (Directive [str ]):
86+ def __init__ (self , name : str ) -> None :
87+ super ().__init__ (name , type = str , default = str , render = str )
6488
6589
66- class BooleanDirective (Directive ):
67- def __init__ (self , name ):
68- parent = super ()
69- parent .__init__ (name , type = bool , default = bool , render = render_bool )
90+ class BooleanDirective (Directive [bool ]):
91+ def __init__ (self , name : str ) -> None :
92+ super ().__init__ (name , type = bool , default = bool , render = render_bool )
7093
7194
72- def is_directive (obj ) :
95+ def is_directive (obj : object ) -> TypeGuard [ Directive [ Any ]] :
7396 return isinstance (obj , Directive )
7497
7598
76- def render_set (value ) :
99+ def render_set (value : set [ str ]) -> str :
77100 return " " .join (sorted (value ))
78101
79102
80- def render_bool (value ) :
103+ def render_bool (value : bool ) -> str | None :
81104 return "" if value else None
82105
83106
@@ -103,39 +126,52 @@ class ContentSecurityPolicy:
103126 """
104127
105128 # Fetch directives
106- child_src = SetDirective ("child-src" )
107- connect_src = SetDirective ("connect-src" )
108- default_src = SetDirective ("default-src" )
109- font_src = SetDirective ("font-src" )
110- frame_src = SetDirective ("frame-src" )
111- img_src = SetDirective ("img-src" )
112- manifest_src = SetDirective ("manifest-src" )
113- media_src = SetDirective ("media-src" )
114- object_src = SetDirective ("object-src" )
115- script_src = SetDirective ("script-src" )
116- style_src = SetDirective ("style-src" )
117- worker_src = SetDirective ("worker-src" )
129+ child_src : SetDirective = SetDirective ("child-src" )
130+ connect_src : SetDirective = SetDirective ("connect-src" )
131+ default_src : SetDirective = SetDirective ("default-src" )
132+ font_src : SetDirective = SetDirective ("font-src" )
133+ frame_src : SetDirective = SetDirective ("frame-src" )
134+ img_src : SetDirective = SetDirective ("img-src" )
135+ manifest_src : SetDirective = SetDirective ("manifest-src" )
136+ media_src : SetDirective = SetDirective ("media-src" )
137+ object_src : SetDirective = SetDirective ("object-src" )
138+ script_src : SetDirective = SetDirective ("script-src" )
139+ style_src : SetDirective = SetDirective ("style-src" )
140+ worker_src : SetDirective = SetDirective ("worker-src" )
118141
119142 # Document directives
120- base_uri = SetDirective ("base-uri" )
121- plugin_types = SetDirective ("plugin-types" )
122- sandbox = SingleValueDirective ("sandbox" )
123- disown_opener = BooleanDirective ("disown-opener" )
143+ base_uri : SetDirective = SetDirective ("base-uri" )
144+ plugin_types : SetDirective = SetDirective ("plugin-types" )
145+ sandbox : SingleValueDirective = SingleValueDirective ("sandbox" )
146+ disown_opener : BooleanDirective = BooleanDirective ("disown-opener" )
124147
125148 # Navigation directives
126- form_action = SetDirective ("form-action" )
127- frame_ancestors = SetDirective ("frame-ancestors" )
149+ form_action : SetDirective = SetDirective ("form-action" )
150+ frame_ancestors : SetDirective = SetDirective ("frame-ancestors" )
128151
129152 # Reporting directives
130- report_uri = SingleValueDirective ("report-uri" )
131- report_to = SingleValueDirective ("report-to" )
153+ report_uri : SingleValueDirective = SingleValueDirective ("report-uri" )
154+ report_to : SingleValueDirective = SingleValueDirective ("report-to" )
132155
133156 # Other directives
134- block_all_mixed_content = BooleanDirective ("block-all-mixed-content" )
135- require_sri_for = SingleValueDirective ("require-sri-for" )
136- upgrade_insecure_requeists = BooleanDirective ("upgrade-insecure-requests" )
137-
138- def __init__ (self , report_only = False , ** directives ):
157+ block_all_mixed_content : BooleanDirective = BooleanDirective (
158+ "block-all-mixed-content"
159+ )
160+ require_sri_for : SingleValueDirective = SingleValueDirective ("require-sri-for" )
161+ upgrade_insecure_requeists : BooleanDirective = BooleanDirective (
162+ "upgrade-insecure-requests"
163+ )
164+
165+ def __init__ (
166+ self ,
167+ report_only : bool = False ,
168+ # NOTE: This is both a little too lax and a little too strict, but
169+ # it doesn't seem worth defining a TypedDict, to get better
170+ # type checking on this, this will work for most cases and
171+ # is not the recommended style of defining the directives
172+ # anyways.
173+ ** directives : set [str ] | str | bool ,
174+ ) -> None :
139175 self .report_only = report_only
140176
141177 for directive in directives :
@@ -144,32 +180,35 @@ def __init__(self, report_only=False, **directives):
144180 assert hasattr (self , name )
145181 setattr (self , name , directives [directive ])
146182
147- def copy (self ):
183+ def copy (self ) -> Self :
148184 policy = self .__class__ ()
149185 policy .__dict__ = deepcopy (self .__dict__ )
150186
151187 return policy
152188
153189 @property
154- def directives (self ):
190+ def directives (self ) -> Generator [ Directive [ Any ]] :
155191 for name , value in inspect .getmembers (self .__class__ , is_directive ):
156192 yield value
157193
158194 @property
159- def text (self ):
160- values = ((d .name , d .render (self )) for d in self .directives )
161- values = ((name , text ) for name , text in values if text is not None )
195+ def text (self ) -> str :
196+ values = (
197+ (d .name , text )
198+ for d in self .directives
199+ if (text := d .render (self )) is not None
200+ )
162201
163202 return ";" .join (" " .join (v ).strip () for v in values )
164203
165204 @property
166- def header_name (self ):
205+ def header_name (self ) -> str :
167206 if self .report_only :
168207 return "Content-Security-Policy-Report-Only"
169208 else :
170209 return "Content-Security-Policy"
171210
172- def apply (self , response ) :
211+ def apply (self , response : BaseResponse ) -> None :
173212 text = self .text
174213
175214 if text :
0 commit comments