def schema(self) -> dict:
    """
    Generates the complete connector schema using a custom schema generator compatible with Pydantic.

    The generator is handed to Pydantic as a class (``ConnectorConfigSchemaGenerator``),
    not as an instance: Pydantic instantiates the ``GenerateJsonSchema`` subclass itself.

    :return: The generated connector schema as a dictionary.
    :raises ValueError: If this configuration was built without a ``config_base_model``,
        in which case there is no model to derive a schema from.
    """
    base_model = self.__base_model
    # ``config_base_model`` is optional in the constructor; fail with an explicit
    # message instead of an AttributeError raised on ``None``.
    if base_model is None:
        raise ValueError(
            "Cannot generate a configuration schema: no config_base_model was provided."
        )
    return base_model.model_json_schema(
        by_alias=False,
        schema_generator=ConnectorConfigSchemaGenerator,
        mode="validation",
    )
schema] + else: + return schema + + return _resolve(deepcopy(schema_with_refs), schema_with_refs) + + @staticmethod + def flatten_config_loader_schema(root_schema: dict): + """ + Flatten config loader schema so all config vars are described at root level. + + :param root_schema: Original schema. + :return: Flatten schema. + """ + flat_json_schema = { + "$schema": root_schema["$schema"], + "$id": root_schema["$id"], + "type": "object", + "properties": {}, + "required": [], + "additionalProperties": root_schema.get("additionalProperties", True), + } + + for ( + config_loader_namespace_name, + config_loader_namespace_schema, + ) in root_schema["properties"].items(): + config_schema = config_loader_namespace_schema.get("properties", {}) + required_config_vars = config_loader_namespace_schema.get("required", []) + + for config_var_name, config_var_schema in config_schema.items(): + property_name = ( + f"{config_loader_namespace_name.upper()}_{config_var_name.upper()}" + ) + + config_var_schema.pop("title", None) + + flat_json_schema["properties"][property_name] = config_var_schema + + if config_var_name in required_config_vars: + flat_json_schema["required"].append(property_name) + + return flat_json_schema + + @staticmethod + def filter_schema(schema): + for filtered_attribute in __FILTERED_ATTRIBUTES__: + if filtered_attribute in schema["properties"]: + del schema["properties"][filtered_attribute] + schema.update( + { + "required": [ + item + for item in schema["required"] + if item != filtered_attribute + ] + } + ) + + return schema + + @override + def generate(self, schema, mode="validation"): + json_schema = super().generate(schema, mode=mode) + + json_schema["$schema"] = self.schema_dialect + json_schema["$id"] = "config.schema.json" + dereferenced_schema = self.dereference_schema(json_schema) + flattened_schema = self.flatten_config_loader_schema(dereferenced_schema) + return self.filter_schema(flattened_schema) + + @override + def nullable_schema(self, schema): 
@classmethod
def settings_customise_sources(
    cls,
    settings_cls: type[BaseSettings],
    init_settings: PydanticBaseSettingsSource,
    env_settings: PydanticBaseSettingsSource,
    dotenv_settings: PydanticBaseSettingsSource,
    file_secret_settings: PydanticBaseSettingsSource,
) -> tuple[PydanticBaseSettingsSource, ...]:
    """Select the settings sources used to build ``settings_cls``.

    Environment variables are always consulted first. At most one file-based
    source follows them:

    1. The YAML file named by ``model_config["yaml_file"]``. When that entry is
       unset, ``config.yml`` is looked up in the current directory and then in
       its parent; if both exist, the parent's file is the one kept.
    2. Otherwise, the dotenv file ``../.env`` (relative to the current
       directory), if it exists.

    If neither file exists, only environment variables are used.

    The ``init_settings``, ``dotenv_settings`` and ``file_secret_settings``
    sources received from Pydantic are not part of the returned tuple.

    Side effects: ``settings_cls.model_config["env_file"]`` is always
    overwritten, and ``settings_cls.model_config["yaml_file"]`` is filled in
    when it was unset and a ``config.yml`` was found.
    """
    config = settings_cls.model_config
    base_dir = os.curdir

    config["env_file"] = f"{base_dir}/../.env"

    if not config["yaml_file"]:
        # Later candidates override earlier ones, so the parent directory wins.
        for candidate in (f"{base_dir}/config.yml", f"{base_dir}/../config.yml"):
            if Path(candidate).is_file():
                config["yaml_file"] = candidate

    yaml_path = Path(config["yaml_file"] or "")  # type: ignore
    if yaml_path.is_file():
        return (env_settings, YamlConfigSettingsSource(settings_cls))

    env_path = Path(config["env_file"] or "")  # type: ignore
    if env_path.is_file():
        return (env_settings, DotEnvSettingsSource(settings_cls))

    return (env_settings,)
+ + Contains common collector settings including identification, logging, + scheduling, and platform information. + """ + + id: str = Field(description="ID of the collector.") + + name: str = Field(description="Name of the collector") + + platform: str | None = Field( + default="SIEM", + description="Platform type for the collector (e.g., EDR, SIEM, etc.).", + ) + log_level: LogLevelToLower | None = Field( + default="error", + description="Determines the verbosity of the logs.", + ) + period: timedelta | None = Field( + default=timedelta(minutes=1), + description="Duration between two scheduled runs of the collector (ISO 8601 format).", + ) + icon_filepath: str | None = Field( + description="Path to the icon file of the collector.", + ) diff --git a/pyoaev/daemons/base_daemon.py b/pyoaev/daemons/base_daemon.py index 549f512..7dcc4e4 100644 --- a/pyoaev/daemons/base_daemon.py +++ b/pyoaev/daemons/base_daemon.py @@ -1,3 +1,4 @@ +import argparse from abc import ABC, abstractmethod from inspect import signature from types import FunctionType @@ -101,6 +102,13 @@ def start(self): follow-up with the main execution loop. Note that at this point, if there is no configured callback, the method will abort and kill the daemon. 
def __init__(
    self,
    base_path,
    variables: Dict | None,
    config_obj: Configuration | None = None,
):
    """Build the helper around a ``Configuration`` object.

    :param base_path: Path whose directory is searched for ``config.yml``.
        Only used when ``config_obj`` is not given.
    :param variables: Configuration hints passed to ``Configuration``.
        Only used when ``config_obj`` is not given.
    :param config_obj: An already-built ``Configuration`` to wrap. Optional, so
        that existing two-argument calls ``OpenAEVConfigHelper(base_path, variables)``
        keep working.
    """
    if config_obj is not None:
        self.__config_obj = config_obj
    else:
        # Legacy path: build the configuration from hints and the config.yml
        # located next to ``base_path``.
        self.__config_obj = Configuration(
            config_hints=variables,
            config_file_path=os.path.join(
                os.path.dirname(os.path.abspath(base_path)), "config.yml"
            ),
        )