Skip to content

Commit 2c9dfdf

Browse files
committed
feat(config): add explicit validation for required configuration fields
1 parent 477160e commit 2c9dfdf

2 files changed

Lines changed: 44 additions & 4 deletions

File tree

qlib/config.py

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -62,16 +62,33 @@ class QSettings(BaseSettings):
6262

6363
class Config:
6464
def __init__(self, default_conf):
65-
self.__dict__["_default_config"] = copy.deepcopy(default_conf) # avoiding conflicts with __getattr__
65+
self.__dict__["_default_config"] = copy.deepcopy(default_conf)
6666
self.reset()
6767

68+
def validate(self):
69+
errors = []
70+
71+
if not self.get("provider_uri"):
72+
errors.append(
73+
"provider_uri must be set (e.g. ~/.qlib/qlib_data or a valid path)"
74+
)
75+
76+
if not self.get("region"):
77+
errors.append(
78+
"region must be specified (e.g. 'cn', 'us')"
79+
)
80+
81+
if errors:
82+
raise ValueError(
83+
"Invalid Qlib configuration:\n- " + "\n- ".join(errors)
84+
)
85+
6886
def __getitem__(self, key):
6987
return self.__dict__["_config"][key]
7088

7189
def __getattr__(self, attr):
7290
if attr in self.__dict__["_config"]:
7391
return self.__dict__["_config"][attr]
74-
7592
raise AttributeError(f"No such `{attr}` in self._config")
7693

7794
def get(self, key, default=None):
@@ -109,14 +126,20 @@ def set_conf_from_C(self, config_c):
109126

110127
@staticmethod
111128
def register_from_C(config, skip_register=True):
112-
from .utils import set_log_with_config # pylint: disable=C0415
129+
from .utils import set_log_with_config
113130

114131
if C.registered and skip_register:
115132
return
116133

134+
117135
C.set_conf_from_C(config)
136+
137+
138+
C.validate()
139+
118140
if C.logging_config:
119141
set_log_with_config(C.logging_config)
142+
120143
C.register()
121144

122145

@@ -523,4 +546,4 @@ def registered(self):
523546

524547

525548
# global config
526-
C = QlibConfig(_default_config)
549+
C = QlibConfig(_default_config)
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
import pytest
2+
from qlib.config import C, Config
3+
4+
5+
def test_missing_provider_uri_raises():
6+
7+
default_conf = {
8+
"provider_uri": None,
9+
"region": "us"
10+
}
11+
12+
cfg = Config(default_conf)
13+
14+
with pytest.raises(ValueError) as exc:
15+
cfg.validate()
16+
17+
assert "provider_uri must be set" in str(exc.value)

0 commit comments

Comments
 (0)