# Source code for tiled.config

"""
This module handles server configuration.

See profiles.py for client configuration.
"""

import copy
from datetime import timedelta
from functools import cached_property
from pathlib import Path
from typing import Annotated, Any, Iterator, Optional, Union

from pydantic import BaseModel, Field, ImportString, field_validator, model_validator

from tiled.server.protocols import ExternalAuthenticator, InternalAuthenticator
from tiled.type_aliases import AppTask, TaskMap

from .adapters.mapping import MapAdapter
from .catalog import from_uri, in_memory
from .media_type_registration import (
    SerializationRegistry,
    default_compression_registry,
    default_deserialization_registry,
    default_serialization_registry,
)
from .query_registration import default_query_registry
from .server.settings import get_settings
from .structures.core import Spec
from .utils import parse, prepend_to_sys_path
from .validation_registration import ValidationRegistry, default_validation_registry

# Short names accepted as a tree's ``tree`` value in the config, mapped to the
# import string they stand for. TreeSpec.tree_alias substitutes these before
# the value is imported as an ImportString.
TREE_ALIASES = {"catalog": "tiled.catalog:from_uri"}


def sub_paths(segments: tuple[str, ...]) -> Iterator[tuple[str, ...]]:
    """Lazily yield every proper prefix of *segments*, shortest first.

    The empty tuple is always yielded first (for non-empty input); the full
    ``segments`` tuple itself is never yielded. For example ``("a", "b")``
    yields ``()`` and then ``("a",)``.
    """
    yield from (segments[:stop] for stop in range(len(segments)))


class TreeSpec(BaseModel):
    """One entry of the ``trees`` list: what to serve and at which path.

    ``tree`` (stored as ``tree_type``) is either an import string / alias that
    resolves to a callable (called with ``args`` to build the tree) or an
    already-constructed, non-callable tree object.
    """

    tree_type: Annotated[ImportString, Field(alias="tree")]
    path: str
    args: Optional[dict[str, Any]] = None

    @model_validator(mode="after")
    def check_callable(self):
        """Reject ``args`` when the tree type cannot be called with them."""
        if self.args and not callable(self.tree_type):
            raise ValueError(
                f"Tree type '{self.tree_type.__class__}' is not callable and cannot take args"
            )
        return self

    @property
    def startup_tasks(self) -> list[AppTask]:
        """Startup tasks declared by the tree, or an empty list.

        Accessing this constructs the tree (via ``self.tree``) if needed.
        """
        return getattr(self.tree, "startup_tasks", [])

    @property
    def shutdown_tasks(self) -> list[AppTask]:
        """Shutdown tasks declared by the tree, or an empty list."""
        return getattr(self.tree, "shutdown_tasks", [])

    @property
    def background_tasks(self) -> list[AppTask]:
        """Background tasks declared by the tree, or an empty list."""
        return getattr(self.tree, "background_tasks", [])

    @cached_property
    def segments(self) -> tuple[str, ...]:
        """The non-empty ``/``-separated components of ``path``."""
        return tuple(segment for segment in self.path.split("/") if segment)

    @cached_property
    def tree(self) -> Any:
        """The tree object, constructed on first access and then cached.

        Construction is deliberately lazy: Config.fudge_tree_args mutates
        ``args`` after validation, so the tree must not be built earlier.
        """
        if callable(self.tree_type):
            return self.tree_type(**self.args or {})
        return self.tree_type

    @property
    def tree_entry(self) -> tuple[tuple[str, ...], Any]:
        """``(segments, tree)`` pair used to merge trees into one root."""
        return (self.segments, self.tree)

    @field_validator("tree_type", mode="before")
    @classmethod
    def tree_alias(cls, value: Any) -> Any:
        """Expand a short alias (see TREE_ALIASES) to its import string."""
        # Only strings can be aliases. Other values (e.g. an already-built
        # tree object passed from Python) may be unhashable, and using them
        # as a dict key would raise TypeError instead of passing them through.
        if isinstance(value, str):
            return TREE_ALIASES.get(value, value)
        return value


class AuthenticationProviderSpec(BaseModel):
    """A named authentication provider and how to construct its authenticator."""

    provider: str
    authenticator: ImportString
    args: Optional[dict[str, Any]] = None

    def into_auth_entry(
        self,
    ) -> tuple[str, Union[InternalAuthenticator, ExternalAuthenticator]]:
        """Instantiate the authenticator and pair it with the provider name.

        Raises ValueError when the constructed object is neither an
        InternalAuthenticator nor an ExternalAuthenticator.
        """
        init_kwargs = self.args if self.args else {}
        instance = self.authenticator(**init_kwargs)
        accepted = (InternalAuthenticator, ExternalAuthenticator)
        if isinstance(instance, accepted):
            return self.provider, instance
        raise ValueError(f"Type {instance.__class__} is not a known authenticator type")


class TiledAdmin(BaseModel):
    """A user to be treated as a Tiled administrator.

    Identified by the authentication provider name together with an id —
    presumably the id that provider reports for the user; confirm against
    the code that consumes ``tiled_admins``.
    """

    provider: str
    id: str


class Authentication(BaseModel):
    """Server authentication settings.

    Every field defaults to None so that "unset" can be told apart from
    "explicitly set to the default value".
    """

    # Defaults are all left as None to differentiate between unset and set to the default
    providers: Optional[list[AuthenticationProviderSpec]] = None
    tiled_admins: Optional[list[TiledAdmin]] = None
    secret_keys: Optional[list[str]] = None
    allow_anonymous_access: Optional[bool] = None
    # Anchored: pydantic v2 patterns are unanchored searches, so without ^...$
    # any string containing a single alphanumeric character would be accepted.
    single_user_api_key: Annotated[
        Optional[str], Field(pattern="^[a-zA-Z0-9]+$")
    ] = None
    access_token_max_age: Optional[timedelta] = None
    refresh_token_max_age: Optional[timedelta] = None
    session_max_age: Optional[timedelta] = None

    @field_validator("providers", mode="after")
    @classmethod
    def check_unique_names(
        cls, value: Optional[list[AuthenticationProviderSpec]]
    ) -> list[AuthenticationProviderSpec]:
        """Require provider names to be unique; normalize an explicit None to [].

        Field validators do not run on defaults (validate_default is not set),
        so an omitted ``providers`` stays None; only an explicit null becomes [].

        Raises:
            ValueError: if two providers share a name (the duplicates are listed).
        """
        if value is None:
            return []
        seen: set[str] = set()
        duplicates: set[str] = set()
        for spec in value:
            if spec.provider in seen:
                duplicates.add(spec.provider)
            seen.add(spec.provider)
        if duplicates:
            raise ValueError(
                "Authenticator provider names must be unique; duplicated: "
                + ", ".join(sorted(duplicates))
            )
        return value

    @cached_property
    def authenticators(
        self,
    ) -> dict[str, Union[InternalAuthenticator, ExternalAuthenticator]]:
        """Map provider name -> constructed authenticator (built once, cached)."""
        return dict(auth.into_auth_entry() for auth in self.providers or ())


class Database(BaseModel):
    """Database connection settings; every field is optional.

    Passed through unchanged as ``server_settings["database"]`` by
    construct_build_app_kwargs. The meaning of each option is defined by the
    consumer — NOTE(review): presumably SQLAlchemy engine/pool options; confirm.
    """

    uri: Optional[str] = None
    init_if_not_exists: Optional[bool] = None
    pool_pre_ping: Optional[bool] = None
    # Validated to be at least 2 when given.
    pool_size: Annotated[Optional[int], Field(ge=2)] = None
    max_overflow: Optional[int] = None


class AccessControl(BaseModel):
    """Spec for the access policy: an importable factory plus its kwargs."""

    access_policy: ImportString
    # Optional *with a default*: in pydantic v2 an Optional annotation without
    # a default is still a required field, which made ``args`` mandatory even
    # though build() already handles it being absent.
    args: Optional[dict[str, Any]] = None

    def build(self):
        """Construct the access policy by calling the factory with ``args``."""
        return self.access_policy(**(self.args or {}))


class MetricsConfig(BaseModel):
    """Metrics settings, passed as ``server_settings["metrics"]``.

    ``prometheus`` defaults to True; presumably it toggles Prometheus metrics
    exposure — confirm in the server code that reads it.
    """

    prometheus: bool = True


class ValidationSpec(BaseModel):
    """A declared spec name with an optional importable validator.

    When ``validator`` is omitted, Config.validation_registry registers a
    no-op validator for the spec.
    """

    spec: str
    validator: Optional[ImportString] = None


class StreamingCache(BaseModel):
    """Streaming-cache settings.

    When present, ``model_dump()`` of this model is injected as the
    ``cache_settings`` argument of ``from_uri`` / ``in_memory`` trees by
    Config.fudge_tree_args. NOTE(review): the cache backend and the units of
    the ttl/timeout values are not visible here (likely seconds) — confirm.
    """

    uri: str
    data_ttl: Optional[int] = None
    seq_ttl: Optional[int] = None
    socket_timeout: Optional[int] = None
    socket_connect_timeout: Optional[int] = None


class Config(BaseModel):
    """Top-level server configuration, as parsed from the config file(s)."""

    trees: list[TreeSpec]
    media_types: dict[str, dict[str, ImportString]] = {}
    file_extensions: dict[str, str] = {}
    authentication: Authentication = Authentication()
    database: Optional[Database] = None
    # TODO: Replace Any with AccessPolicy when #1044 is merged
    access_policy: Annotated[Optional[Any], Field(alias="access_control")] = None
    response_bytesize_limit: int = 300_000_000
    exact_count_limit: Optional[int] = None
    allow_origins: Optional[list[str]] = None
    uvicorn: dict[str, Any] = {}
    metrics: MetricsConfig = MetricsConfig()
    specs: list[ValidationSpec] = []
    reject_undeclared_specs: bool = False
    expose_raw_assets: bool = True

    # Pool settings forwarded to every `from_uri` (catalog) tree; None means
    # "use the value from the server settings" (see fudge_tree_args).
    catalog_pool_size: Optional[int] = None
    storage_pool_size: Optional[int] = None
    catalog_max_overflow: Optional[int] = None
    storage_max_overflow: Optional[int] = None

    streaming_cache: Optional[StreamingCache] = None

    @field_validator("access_policy")
    @classmethod
    def check_access_policy(cls, value: Any) -> Any:
        """Convert the access policy spec into the constructed instance.

        An explicit ``access_control: null`` leaves access control unset
        instead of failing validation.
        """
        if value is None:
            return None
        access = AccessControl.model_validate(value)
        return access.build()

    @field_validator("trees")
    @classmethod
    def non_overlapping_trees(cls, trees: list[TreeSpec]) -> list[TreeSpec]:
        """Ensure that paths to trees do not collide.

        Paths are checked shortest-first, so any earlier path that equals a
        proper prefix of (or is identical to) a later one is reported.
        """
        paths = set()
        for path in sorted((t.segments for t in trees), key=len):
            for sub in (*sub_paths(path), path):
                if sub in paths:
                    raise ValueError(
                        f"Tree paths cannot be subpaths of each other: '/{'/'.join(sub)}' and '/{'/'.join(path)}'"
                    )
            paths.add(path)
        return trees

    @model_validator(mode="after")
    def fudge_tree_args(self):
        """Inject server-level settings into the args of built-in tree types.

        ``from_uri`` trees receive the four pool settings; ``from_uri`` and
        ``in_memory`` trees receive ``cache_settings``. Every tree's ``args``
        is normalized to a dict.
        """
        # Needing to fudge the args of tree specs is awful
        defaults = None  # fetched lazily, at most once, only if needed
        for tree in self.trees:
            tree.args = tree.args or {}
            if tree.tree_type is from_uri:
                if defaults is None:
                    defaults = get_settings()
                for name in (
                    "catalog_pool_size",
                    "storage_pool_size",
                    "catalog_max_overflow",
                    "storage_max_overflow",
                ):
                    configured = getattr(self, name)
                    # Explicit None check rather than `or`: an explicitly
                    # configured 0 must not be silently swapped for the default
                    # (e.g. SQLAlchemy treats max_overflow=0 as "no overflow").
                    tree.args[name] = (
                        configured
                        if configured is not None
                        else getattr(defaults, name)
                    )
            if tree.tree_type in (from_uri, in_memory):
                # A fresh dump per tree so trees never share one mutable dict.
                tree.args["cache_settings"] = (
                    self.streaming_cache.model_dump() if self.streaming_cache else None
                )
        return self

    @property
    def root_path(self) -> str:
        """``uvicorn.root_path`` if set and truthy, else ``""``."""
        return self.uvicorn.get("root_path") or ""

    @cached_property
    def merged_trees(self) -> Any:  # TODO: update when # 1047 is merged
        """The single root tree to serve, merging all configured trees.

        A lone tree at ``/`` is served as-is. Otherwise trees are nested in
        MapAdapters by path segment, and every tree's ``include_routers`` are
        collected (deduplicated, first-seen order) onto the root MapAdapter.
        """
        trees = dict(tree.tree_entry for tree in self.trees)
        if list(trees) == [()]:
            # Simple case: there is one tree, served at the root path /.
            root_tree = trees[()]
        else:
            # There are one or more tree(s) to be served at sub-paths so merge
            # them into one root MapAdapter with map path segments to dicts
            # containing Adapters at that path.
            root_mapping = trees.pop((), {})
            index: dict[tuple[str, ...], dict] = {(): root_mapping}
            # A dict (not a set) dedupes while keeping a deterministic,
            # first-seen order; router inclusion order can affect routing.
            all_routers: dict[Any, None] = {}

            # for rest of trees, build up parent nodes if required
            for segments, tree in trees.items():
                for subpath in sub_paths(segments):
                    if subpath not in index:
                        mapping = {}
                        index[subpath] = mapping
                        index[subpath[:-1]][subpath[-1]] = MapAdapter(mapping)
                index[segments[:-1]][segments[-1]] = tree
                for router in getattr(tree, "include_routers", []):
                    all_routers.setdefault(router, None)

            root_tree = MapAdapter(root_mapping)
            root_tree.include_routers.extend(all_routers)
        return root_tree

    def tree_tasks(self) -> TaskMap:
        """Collect startup/shutdown/background tasks from all trees."""
        startup = []
        shutdown = []
        background = []
        for tree in self.trees:
            startup.extend(tree.startup_tasks)
            shutdown.extend(tree.shutdown_tasks)
            background.extend(tree.background_tasks)
        return {
            "startup": startup,
            "shutdown": shutdown,
            "background": background,
        }

    def serialization_registry(self) -> SerializationRegistry:
        """A copy of the default serialization registry plus configured extras.

        The module-level default is deep-copied so it is never mutated.
        """
        base = copy.deepcopy(default_serialization_registry)
        for family, types in self.media_types.items():
            for typ, func in types.items():
                base.register(family, typ, func)
        for ext, media_type in self.file_extensions.items():
            base.register_alias(ext, media_type)
        return base

    def validation_registry(self) -> ValidationRegistry:
        """A copy of the default validation registry plus configured specs.

        Specs without a validator get a no-op validator.
        """
        base = copy.deepcopy(default_validation_registry)
        for spec in self.specs:
            base.register(Spec(spec.spec), spec.validator or _no_op_validator)
        return base


def parse_configs(src_file: Union[str, Path]) -> Config:
    """Parse and validate server configuration from a file or a directory.

    Parameters
    ----------
    src_file : str or Path
        A single config file, or a directory whose ``*.yml`` files are merged
        into one configuration. Directory entries are read in sorted order so
        that duplicate-key errors are reported deterministically.

    Returns
    -------
    Config
        The validated configuration.

    Raises
    ------
    ValueError
        If two files in a directory define the same top-level key.
    """
    src_file = Path(src_file)
    if src_file.is_dir():
        conf = {}
        for f in sorted(src_file.iterdir()):
            if f.is_file() and f.suffix == ".yml":
                # NOTE(review): an empty YAML file may parse to None; treat it
                # as contributing no keys instead of crashing on `.keys()`.
                new_config = parse(f) or {}
                if common := new_config.keys() & conf.keys():
                    raise ValueError(f"Duplicate configuration for {common} in {f}")
                conf.update(new_config)
    else:
        conf = parse(src_file)
    # Put the config's directory at the front of sys.path while validating,
    # presumably so ImportString values can import modules placed next to it.
    with prepend_to_sys_path(src_file if src_file.is_dir() else src_file.parent):
        return Config.model_validate(conf)
def construct_build_app_kwargs(config: Config):
    """Translate a validated Config into keyword arguments for building the app.

    Server-level options are grouped under ``server_settings``; registries not
    configurable through Config (query, deserialization, compression) use the
    module defaults.
    """
    server_settings = {
        name: getattr(config, name)
        for name in (
            "root_path",
            "allow_origins",
            "response_bytesize_limit",
            "exact_count_limit",
            "database",
            "reject_undeclared_specs",
            "expose_raw_assets",
            "metrics",
        )
    }
    return {
        "tree": config.merged_trees,
        "authentication": config.authentication,
        "server_settings": server_settings,
        "query_registry": default_query_registry,
        "serialization_registry": config.serialization_registry(),
        "deserialization_registry": default_deserialization_registry,
        "compression_registry": default_compression_registry,
        "validation_registry": config.validation_registry(),
        "tasks": config.tree_tasks(),
        "access_policy": config.access_policy,
    }
def _no_op_validator(*args, **kwargs): return None