Skip to content

Commit 0d1785e

Browse files
authored
Merge branch 'release/current' into main
2 parents f5bb330 + c2a050b commit 0d1785e

10 files changed

Lines changed: 349 additions & 12 deletions

File tree

pyoaev/configuration/__init__.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,15 @@
11
from .configuration import Configuration
2+
from .settings_loader import (
3+
BaseConfigModel,
4+
ConfigLoaderCollector,
5+
ConfigLoaderOAEV,
6+
SettingsLoader,
7+
)
28

3-
__all__ = ["Configuration"]
9+
__all__ = [
10+
"Configuration",
11+
"ConfigLoaderOAEV",
12+
"ConfigLoaderCollector",
13+
"SettingsLoader",
14+
"BaseConfigModel",
15+
]

pyoaev/configuration/configuration.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,11 @@
44

55
import yaml
66
from pydantic import BaseModel, Field
7+
from pydantic_settings import BaseSettings
78

9+
from pyoaev.configuration.connector_config_schema_generator import (
10+
ConnectorConfigSchemaGenerator,
11+
)
812
from pyoaev.configuration.sources import DictionarySource, EnvironmentSource
913

1014
CONFIGURATION_TYPES = str | int | bool | Any | None
@@ -111,6 +115,7 @@ def __init__(
111115
config_hints: Dict[str, dict | str],
112116
config_values: dict = None,
113117
config_file_path: str = os.path.join(os.curdir, "config.yml"),
118+
config_base_model: BaseSettings = None,
114119
):
115120
self.__config_hints = {
116121
key: (
@@ -129,6 +134,8 @@ def __init__(
129134

130135
self.__config_values = (config_values or {}) | file_contents
131136

137+
self.__base_model = config_base_model
138+
132139
def get(self, config_key: str) -> CONFIGURATION_TYPES:
133140
"""Gets the value pointed to by the configuration key. If the key is defined
134141
with actual hints (as opposed to a discrete value), it will use those hints to
@@ -146,7 +153,12 @@ def get(self, config_key: str) -> CONFIGURATION_TYPES:
146153
return None
147154

148155
return self.__process_value_to_type(
149-
config.data or self.__dig_config_sources_for_key(config), config.is_number
156+
(
157+
self.__dig_config_sources_for_key(config)
158+
if config.data is None
159+
else config.data
160+
),
161+
config.is_number,
150162
)
151163

152164
def set(self, config_key: str, value: CONFIGURATION_TYPES):
@@ -164,6 +176,19 @@ def set(self, config_key: str, value: CONFIGURATION_TYPES):
164176
else:
165177
self.__config_hints[config_key].data = value
166178

179+
def schema(self):
180+
"""
181+
Generates the complete connector schema using a custom schema generator compatible with Pydantic.
182+
Isolate custom class generator, Pydantic expects a class, not an instance
183+
Always subclass GenerateJsonSchema and pass the class to Pydantic, not an instance
184+
:return: The generated connector schema as a dictionary.
185+
"""
186+
return self.__base_model.model_json_schema(
187+
by_alias=False,
188+
schema_generator=ConnectorConfigSchemaGenerator,
189+
mode="validation",
190+
)
191+
167192
@staticmethod
168193
def __process_value_to_type(value: CONFIGURATION_TYPES, is_number_hint: bool):
169194
if value is None:
Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
## ADAPTED FROM https://github.com/OpenCTI-Platform/connectors/blob/5c8cf1235f62f5651c9c08d0b67f1bd182662c8a/shared/tools/composer/generate_connectors_config_schemas/generate_connector_config_json_schema.py.sample
2+
3+
from copy import deepcopy
4+
from typing import override
5+
6+
from pydantic.json_schema import GenerateJsonSchema
7+
8+
# attributes filtered from the connector configuration before generating the manifest
9+
__FILTERED_ATTRIBUTES__ = [
10+
# connector id is generated
11+
"CONNECTOR_ID",
12+
]
13+
14+
15+
class ConnectorConfigSchemaGenerator(GenerateJsonSchema):
16+
@staticmethod
17+
def dereference_schema(schema_with_refs):
18+
"""Return a new schema with all internal $ref resolved."""
19+
20+
def _resolve(schema, root):
21+
if isinstance(schema, dict):
22+
if "$ref" in schema:
23+
ref_path = schema["$ref"]
24+
if ref_path.startswith("#/$defs/"):
25+
def_name = ref_path.split("/")[-1]
26+
# Deep copy to avoid mutating $defs
27+
resolved = deepcopy(root["$defs"][def_name])
28+
return _resolve(resolved, root)
29+
else:
30+
raise ValueError(f"Unsupported ref format: {ref_path}")
31+
else:
32+
return {
33+
schema_key: _resolve(schema_value, root)
34+
for schema_key, schema_value in schema.items()
35+
}
36+
elif isinstance(schema, list):
37+
return [_resolve(item, root) for item in schema]
38+
else:
39+
return schema
40+
41+
return _resolve(deepcopy(schema_with_refs), schema_with_refs)
42+
43+
@staticmethod
44+
def flatten_config_loader_schema(root_schema: dict):
45+
"""
46+
Flatten config loader schema so all config vars are described at root level.
47+
48+
:param root_schema: Original schema.
49+
:return: Flatten schema.
50+
"""
51+
flat_json_schema = {
52+
"$schema": root_schema["$schema"],
53+
"$id": root_schema["$id"],
54+
"type": "object",
55+
"properties": {},
56+
"required": [],
57+
"additionalProperties": root_schema.get("additionalProperties", True),
58+
}
59+
60+
for (
61+
config_loader_namespace_name,
62+
config_loader_namespace_schema,
63+
) in root_schema["properties"].items():
64+
config_schema = config_loader_namespace_schema.get("properties", {})
65+
required_config_vars = config_loader_namespace_schema.get("required", [])
66+
67+
for config_var_name, config_var_schema in config_schema.items():
68+
property_name = (
69+
f"{config_loader_namespace_name.upper()}_{config_var_name.upper()}"
70+
)
71+
72+
config_var_schema.pop("title", None)
73+
74+
flat_json_schema["properties"][property_name] = config_var_schema
75+
76+
if config_var_name in required_config_vars:
77+
flat_json_schema["required"].append(property_name)
78+
79+
return flat_json_schema
80+
81+
@staticmethod
82+
def filter_schema(schema):
83+
for filtered_attribute in __FILTERED_ATTRIBUTES__:
84+
if filtered_attribute in schema["properties"]:
85+
del schema["properties"][filtered_attribute]
86+
schema.update(
87+
{
88+
"required": [
89+
item
90+
for item in schema["required"]
91+
if item != filtered_attribute
92+
]
93+
}
94+
)
95+
96+
return schema
97+
98+
@override
99+
def generate(self, schema, mode="validation"):
100+
json_schema = super().generate(schema, mode=mode)
101+
102+
json_schema["$schema"] = self.schema_dialect
103+
json_schema["$id"] = "config.schema.json"
104+
dereferenced_schema = self.dereference_schema(json_schema)
105+
flattened_schema = self.flatten_config_loader_schema(dereferenced_schema)
106+
return self.filter_schema(flattened_schema)
107+
108+
@override
109+
def nullable_schema(self, schema):
110+
"""Generates a JSON schema that matches a schema that allows null values.
111+
112+
Args:
113+
schema: The core schema.
114+
115+
Returns:
116+
The generated JSON schema.
117+
118+
Notes:
119+
This method overrides `GenerateJsonSchema.nullable_schema` to generate schemas without `anyOf` keyword.
120+
"""
121+
null_schema = {"type": "null"}
122+
inner_json_schema = self.generate_inner(schema["schema"])
123+
124+
if inner_json_schema == null_schema:
125+
return null_schema
126+
else:
127+
return inner_json_schema
Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
import os
2+
from abc import ABC
3+
from datetime import timedelta
4+
from pathlib import Path
5+
from typing import Annotated, Literal
6+
7+
from pydantic import BaseModel, ConfigDict, Field, HttpUrl, PlainSerializer
8+
from pydantic_settings import (
9+
BaseSettings,
10+
DotEnvSettingsSource,
11+
PydanticBaseSettingsSource,
12+
SettingsConfigDict,
13+
YamlConfigSettingsSource,
14+
)
15+
16+
17+
class BaseConfigModel(BaseModel, ABC):
18+
"""Base class for global config models
19+
To prevent attributes from being modified after initialization.
20+
"""
21+
22+
model_config = ConfigDict(extra="allow", frozen=True, validate_default=True)
23+
24+
25+
class SettingsLoader(BaseSettings):
26+
model_config = SettingsConfigDict(
27+
frozen=True,
28+
extra="allow",
29+
env_nested_delimiter="_",
30+
env_nested_max_split=1,
31+
enable_decoding=False,
32+
)
33+
34+
@classmethod
35+
def settings_customise_sources(
36+
cls,
37+
settings_cls: type[BaseSettings],
38+
init_settings: PydanticBaseSettingsSource,
39+
env_settings: PydanticBaseSettingsSource,
40+
dotenv_settings: PydanticBaseSettingsSource,
41+
file_secret_settings: PydanticBaseSettingsSource,
42+
) -> tuple[PydanticBaseSettingsSource, ...]:
43+
"""Customise the sources of settings for the connector.
44+
45+
This method is called by the Pydantic BaseSettings class to determine the order of sources.
46+
The configuration come in this order either from:
47+
1. Environment variables
48+
2. YAML file
49+
3. .env file
50+
4. Default values
51+
52+
The variables loading order will remain the same as in `pycti.get_config_variable()`:
53+
1. If a config.yml file is found, the order will be: `ENV VAR` → config.yml → default value
54+
2. If a .env file is found, the order will be: `ENV VAR` → .env → default value
55+
"""
56+
_main_path = os.curdir
57+
58+
settings_cls.model_config["env_file"] = f"{_main_path}/../.env"
59+
60+
if not settings_cls.model_config["yaml_file"]:
61+
if Path(f"{_main_path}/config.yml").is_file():
62+
settings_cls.model_config["yaml_file"] = f"{_main_path}/config.yml"
63+
if Path(f"{_main_path}/../config.yml").is_file():
64+
settings_cls.model_config["yaml_file"] = f"{_main_path}/../config.yml"
65+
66+
if Path(settings_cls.model_config["yaml_file"] or "").is_file(): # type: ignore
67+
return (
68+
env_settings,
69+
YamlConfigSettingsSource(settings_cls),
70+
)
71+
if Path(settings_cls.model_config["env_file"] or "").is_file(): # type: ignore
72+
return (
73+
env_settings,
74+
DotEnvSettingsSource(settings_cls),
75+
)
76+
return (env_settings,)
77+
78+
79+
LogLevelToLower = Annotated[
80+
Literal["debug", "info", "warn", "error"],
81+
PlainSerializer(lambda v: "".join(v), return_type=str),
82+
]
83+
84+
HttpUrlToString = Annotated[HttpUrl, PlainSerializer(str, return_type=str)]
85+
TimedeltaInSeconds = Annotated[
86+
timedelta, PlainSerializer(lambda v: int(v.total_seconds()), return_type=int)
87+
]
88+
89+
90+
class ConfigLoaderOAEV(BaseConfigModel):
91+
"""OpenAEV/OpenAEV platform configuration settings.
92+
93+
Contains URL and authentication token for connecting to the OpenAEV platform.
94+
"""
95+
96+
url: HttpUrlToString = Field(
97+
description="The OpenAEV platform URL.",
98+
)
99+
token: str = Field(
100+
description="The token for the OpenAEV platform.",
101+
)
102+
103+
104+
class ConfigLoaderCollector(BaseConfigModel):
105+
"""Base collector configuration settings.
106+
107+
Contains common collector settings including identification, logging,
108+
scheduling, and platform information.
109+
"""
110+
111+
id: str = Field(description="ID of the collector.")
112+
113+
name: str = Field(description="Name of the collector")
114+
115+
log_level: LogLevelToLower | None = Field(
116+
default="error",
117+
description="Determines the verbosity of the logs.",
118+
)
119+
period: timedelta | None = Field(
120+
default=timedelta(minutes=1),
121+
description="Duration between two scheduled runs of the collector (ISO 8601 format).",
122+
)
123+
icon_filepath: str | None = Field(
124+
description="Path to the icon file of the collector.",
125+
)

pyoaev/contracts/contract_config.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,15 @@ class ContractConfig:
120120
color_light: str
121121

122122

123+
@dataclass
124+
class Domain:
125+
domain_id: str
126+
domain_name: str
127+
domain_color: str
128+
domain_created_at: str
129+
domain_updated_at: str
130+
131+
123132
@dataclass
124133
class Contract:
125134
contract_id: str
@@ -141,6 +150,7 @@ class Contract:
141150
is_atomic_testing: bool = True
142151
platforms: List[str] = field(default_factory=list)
143152
external_id: str = None
153+
domains: List[Domain] = None
144154

145155
def add_attack_pattern(self, var: str):
146156
self.contract_attack_patterns_external_ids.append(var)
@@ -163,6 +173,7 @@ def to_contract_add_input(self, source_id: str):
163173
"contract_content": json.dumps(self, cls=utils.EnhancedJSONEncoder),
164174
"is_atomic_testing": self.is_atomic_testing,
165175
"contract_platforms": self.platforms,
176+
"contract_domains": self.domains,
166177
}
167178

168179
def to_contract_update_input(self):
@@ -174,6 +185,7 @@ def to_contract_update_input(self):
174185
"contract_content": json.dumps(self, cls=utils.EnhancedJSONEncoder),
175186
"is_atomic_testing": self.is_atomic_testing,
176187
"contract_platforms": self.platforms,
188+
"contract_domains": self.domains,
177189
}
178190

179191

@@ -203,6 +215,7 @@ def prepare_contracts(contracts):
203215
"contract_attack_patterns_external_ids": c.contract_attack_patterns_external_ids,
204216
"contract_content": json.dumps(c, cls=utils.EnhancedJSONEncoder),
205217
"contract_platforms": c.platforms,
218+
"contract_domains": c.domains,
206219
},
207220
contracts,
208221
)

0 commit comments

Comments
 (0)