"""Configuration models and YAML config-file loading.

String values in the YAML file may reference environment variables as
``${NAME}`` or ``${NAME:default}``; these are expanded while parsing.
"""

import os
import re
from typing import Any, Dict, List, Literal, Optional

import structlog
import yaml
from pydantic import BaseModel, Field

LOGGER = structlog.getLogger(__name__)

# One ``${...}`` placeholder; group 1 is the brace-free body ``NAME`` or
# ``NAME:default``.
_var_matcher = re.compile(r"\${([^}^{]+)}")
# A scalar whose first ``$`` opens a ``${...}`` placeholder; used as the
# implicit resolver pattern that tags such scalars as ``!envvar``.
_tag_matcher = re.compile(r"[^$]*\${([^}^{]+)}.*")


class RateLimitConfig(BaseModel):
    """Rate-limit settings; disabled unless ``enabled`` is true."""

    enabled: bool = Field(default=False)
    # NOTE(review): presumably a "<count>/<period>" expression interpreted by
    # the rate limiter — confirm against the consumer.
    limit: str = Field(default="100/minute")


class CacheConfig(BaseModel):
    """Cache settings."""

    # presumably seconds — confirm against the cache implementation.
    ttl: int = Field(default=60)
    # None means no explicit size limit is configured.
    max_size: Optional[int] = Field(default=None)


class AuthConfig(BaseModel):
    """Authentication settings.

    ``type`` is required. ``token`` presumably applies to ``http_bearer`` and
    ``username``/``password`` to ``http_basic``; the model itself does not
    enforce which fields accompany which type.
    """

    type: Literal["http_bearer", "http_basic"] = Field()
    token: Optional[str] = Field(default=None)
    username: Optional[str] = Field(default=None)
    password: Optional[str] = Field(default=None)


class TracingConfig(BaseModel):
    """Tracing exporter settings."""

    exporter: Literal["otel_http", "console"] = Field(default="console")
    # presumably only meaningful for the "otel_http" exporter.
    endpoint: Optional[str] = Field(default=None)


class MetricsConfig(BaseModel):
    """Metrics exporter settings."""

    exporter: Literal["otel_http", "prometheus", "console"] = Field(default="console")
    # presumably only meaningful for the "otel_http" exporter.
    endpoint: Optional[str] = Field(default=None)


class AppConfig(BaseModel):
    """General application settings."""

    name: Optional[str] = Field(default="LLM Guard API")
    port: Optional[int] = Field(default=7860)
    log_level: Optional[str] = Field(default="INFO")
    scan_fail_fast: Optional[bool] = Field(default=False)
    # Timeouts are presumably in seconds — confirm against the scan code.
    scan_prompt_timeout: Optional[int] = Field(default=10)
    scan_output_timeout: Optional[int] = Field(default=30)


class ScannerConfig(BaseModel):
    """One scanner entry: a scanner ``type`` plus a free-form ``params`` dict."""

    type: str
    params: Optional[Dict] = Field(default_factory=dict)


class Config(BaseModel):
    """Root model for the whole configuration file.

    ``input_scanners`` and ``output_scanners`` are required; the optional
    sections (``auth``, ``tracing``, ``metrics``) are ``None`` when absent.
    """

    input_scanners: List[ScannerConfig] = Field()
    output_scanners: List[ScannerConfig] = Field()
    rate_limit: RateLimitConfig = Field(default_factory=RateLimitConfig)
    cache: CacheConfig = Field(default_factory=CacheConfig)
    auth: Optional[AuthConfig] = Field(default=None)
    app: AppConfig = Field(default_factory=AppConfig)
    tracing: Optional[TracingConfig] = Field(default=None)
    metrics: Optional[MetricsConfig] = Field(default=None)


def _path_constructor(_loader: Any, node: Any) -> str:
    """Return the scalar text of *node* with every ``${NAME[:default]}`` expanded.

    ``NAME`` is looked up in ``os.environ``. When it is unset, everything after
    the first ``:`` is used as the default (an empty string if there is none).
    """

    def replace_fn(match: "re.Match[str]") -> str:
        # partition() splits on the FIRST colon only, so defaults that contain
        # colons themselves (URLs, "host:port") are preserved intact.
        name, _, default = match.group(1).partition(":")
        return os.environ.get(name, default)

    return _var_matcher.sub(replace_fn, node.value)


class _EnvVarSafeLoader(yaml.SafeLoader):
    """SafeLoader subclass that expands ``${...}`` placeholders in scalars.

    Registering on a private subclass (instead of the global
    ``yaml.SafeLoader``) leaves other YAML users in the process unaffected and
    performs the registration exactly once.
    """


_EnvVarSafeLoader.add_implicit_resolver("!envvar", _tag_matcher, None)
_EnvVarSafeLoader.add_constructor("!envvar", _path_constructor)


def load_yaml(filename: str) -> dict:
    """Parse *filename* as YAML, expanding environment-variable placeholders.

    Args:
        filename: Path of the YAML file to read (decoded as UTF-8).

    Returns:
        The top-level mapping. An empty dict is returned when the file cannot
        be read or decoded, is not valid YAML, is empty, or does not contain a
        mapping; failures are logged rather than raised.
    """
    try:
        with open(filename, "r", encoding="utf-8") as f:
            # _EnvVarSafeLoader only adds a string-producing constructor to
            # SafeLoader, so this is as safe as yaml.safe_load.
            data = yaml.load(f, Loader=_EnvVarSafeLoader)
    except (OSError, UnicodeDecodeError, yaml.YAMLError) as exc:
        LOGGER.error("Error loading YAML file", exception=exc)
        return dict()

    if data is None:
        # An empty document parses to None; treat it as "no configuration".
        return dict()
    if not isinstance(data, dict):
        LOGGER.error("YAML file does not contain a mapping", file_name=filename)
        return dict()
    return data


def get_config(file_name: str) -> Optional[Config]:
    """Load *file_name* and validate its contents into a :class:`Config`.

    Returns:
        The validated configuration, or ``None`` when the file yields no
        configuration (unreadable, invalid, empty, or non-mapping YAML).

    Raises:
        Whatever ``Config`` raises for invalid content, e.g. when the required
        ``input_scanners`` / ``output_scanners`` keys are missing.
    """
    LOGGER.debug("Loading config file", file_name=file_name)

    conf = load_yaml(file_name)
    if not conf:
        return None

    return Config(**conf)