77from typing import List
88from typing import Optional
99
10+ from oidcmsg .configure import Base
11+
1012from oidcrp .logging import configure_logging
1113from oidcrp .util import load_yaml_config
1214from oidcrp .util import lower_or_upper
1618except ImportError :
1719 from cryptojwt import rndstr as rnd_token
1820
19- DEFAULT_FILE_ATTRIBUTE_NAMES = ['server_key' , 'server_cert' , 'filename' , 'template_dir' ,
20- 'private_path' , 'public_path' , 'db_file' ]
21-
22-
23- def add_base_path (conf : dict , base_path : str , file_attributes : List [str ]):
24- for key , val in conf .items ():
25- if key in file_attributes :
26- if val .startswith ("/" ):
27- continue
28- elif val == "" :
29- conf [key ] = "./" + val
30- else :
31- conf [key ] = os .path .join (base_path , val )
32- if isinstance (val , dict ):
33- conf [key ] = add_base_path (val , base_path , file_attributes )
34-
35- return conf
36-
37-
38- def set_domain_and_port (conf : dict , uris : List [str ], domain : str , port : int ):
39- for key , val in conf .items ():
40- if key in uris :
41- if not val :
42- continue
43-
44- if isinstance (val , list ):
45- _new = [v .format (domain = domain , port = port ) for v in val ]
46- else :
47- _new = val .format (domain = domain , port = port )
48- conf [key ] = _new
49- elif isinstance (val , dict ):
50- conf [key ] = set_domain_and_port (val , uris , domain , port )
51- return conf
52-
53-
54- class Base :
55- """ Configuration base class """
56-
57- def __init__ (self ,
58- conf : Dict ,
59- base_path : str = '' ,
60- file_attributes : Optional [List [str ]] = None ,
61- dir_attributes : Optional [List [str ]] = None
62- ):
63-
64- if file_attributes is None :
65- file_attributes = DEFAULT_FILE_ATTRIBUTE_NAMES
66-
67- if base_path and file_attributes :
68- # this adds a base path to all paths in the configuration
69- add_base_path (conf , base_path , file_attributes )
70-
71- def __getitem__ (self , item ):
72- if item in self .__dict__ :
73- return self .__dict__ [item ]
74- else :
75- raise KeyError
76-
77- def get (self , item , default = None ):
78- return getattr (self , item , default )
79-
80- def __contains__ (self , item ):
81- return item in self .__dict__
82-
83- def items (self ):
84- for key in self .__dict__ :
85- if key .startswith ('__' ) and key .endswith ('__' ):
86- continue
87- yield key , getattr (self , key )
88-
89- def extend (self , entity_conf , conf , base_path , file_attributes , domain , port ):
90- for econf in entity_conf :
91- _path = econf .get ("path" )
92- _cnf = conf
93- if _path :
94- for step in _path :
95- _cnf = _cnf [step ]
96- _attr = econf ["attr" ]
97- _cls = econf ["class" ]
98- setattr (self , _attr ,
99- _cls (_cnf , base_path = base_path , file_attributes = file_attributes ,
100- domain = domain , port = port ))
101-
102-
10321URIS = [
10422 "redirect_uris" , 'post_logout_redirect_uris' , 'frontchannel_logout_uri' ,
10523 'backchannel_logout_uri' , 'issuer' , 'base_url' ]
@@ -113,23 +31,17 @@ def __init__(self,
11331 domain : Optional [str ] = "127.0.0.1" ,
11432 port : Optional [int ] = 80 ,
11533 file_attributes : Optional [List [str ]] = None ,
34+ dir_attributes : Optional [List [str ]] = None ,
11635 ):
11736
118- Base .__init__ (self , conf , base_path = base_path , file_attributes = file_attributes )
119-
120- _keys_conf = lower_or_upper (conf , 'rp_keys' )
121- if _keys_conf is None :
122- _keys_conf = lower_or_upper (conf , 'oidc_keys' ) # legacy
123-
124- self .keys = _keys_conf
125-
126- if not domain :
127- domain = conf .get ("domain" , "127.0.0.1" )
37+ Base .__init__ (self , conf ,
38+ base_path = base_path ,
39+ domain = domain ,
40+ port = port ,
41+ file_attributes = file_attributes ,
42+ dir_attributes = dir_attributes )
12843
129- if not port :
130- port = conf .get ("port" , 80 )
131-
132- conf = set_domain_and_port (conf , URIS , domain , port )
44+ self .key_conf = lower_or_upper (conf , 'rp_keys' ) or lower_or_upper (conf , 'oidc_keys' )
13345 self .clients = lower_or_upper (conf , "clients" )
13446
13547 hash_seed = lower_or_upper (conf , 'hash_seed' )
@@ -167,13 +79,6 @@ def __init__(self,
16779
16880 self .web_conf = lower_or_upper (conf , "webserver" )
16981
170- # entity info
171- if not domain :
172- domain = conf .get ("domain" , "127.0.0.1" )
173-
174- if not port :
175- port = conf .get ("port" , 80 )
176-
17782 if entity_conf :
17883 self .extend (entity_conf = entity_conf , conf = conf , base_path = base_path ,
17984 file_attributes = file_attributes , domain = domain , port = port )
@@ -184,6 +89,7 @@ def create_from_config_file(cls,
18489 base_path : Optional [str ] = '' ,
18590 entity_conf : Optional [List [dict ]] = None ,
18691 file_attributes : Optional [List [str ]] = None ,
92+ dir_attributes : Optional [List [str ]] = None ,
18793 domain : Optional [str ] = "" ,
18894 port : Optional [int ] = 0 ):
18995 if filename .endswith (".yaml" ):
@@ -203,4 +109,4 @@ def create_from_config_file(cls,
203109 return cls (_cnf ,
204110 entity_conf = entity_conf ,
205111 base_path = base_path , file_attributes = file_attributes ,
206- domain = domain , port = port )
112+ domain = domain , port = port , dir_attributes = dir_attributes )
0 commit comments