Spaces:
Sleeping
Sleeping
import os | |
import re | |
from typing import Any, Dict, List, Literal, Optional | |
import structlog | |
import yaml | |
from pydantic import BaseModel, Field | |
# Module logger; structlog.get_logger is the same function as structlog.getLogger.
LOGGER = structlog.get_logger(__name__)

# Locates each `${...}` placeholder; group(1) is the text between the braces,
# e.g. `VAR` or `VAR:default`. The class excludes `}`, `{` and a literal `^`
# (NOTE(review): the `^` exclusion looks unintentional — confirm).
_var_matcher = re.compile(r"\${([^}^{]+)}")

# Pattern handed to load_yaml's `!envvar` implicit resolver: it matches (from
# the start of the value) strings whose first `$` opens a `${...}` placeholder.
_tag_matcher = re.compile(r"[^$]*\${([^}^{]+)}.*")
class RateLimitConfig(BaseModel):
    # Rate-limiting settings; disabled unless `enabled` is set.
    enabled: bool = False
    # Limit expression. The default "100/minute" suggests a "<count>/<period>"
    # format — confirm against the limiter that consumes this value.
    limit: str = "100/minute"
class CacheConfig(BaseModel):
    # Cache settings.
    # Entry time-to-live; units are not visible here (presumably seconds — confirm).
    ttl: int = 60
    # Size cap; None leaves it unset (how that is interpreted is up to the consumer).
    max_size: Optional[int] = None
class AuthConfig(BaseModel):
    # Authentication settings. `type` is required; all credentials are optional
    # and no cross-field validation is performed here.
    type: Literal["http_bearer", "http_basic"]
    # Presumably used when type == "http_bearer" — confirm with the consumer.
    token: Optional[str] = None
    # Presumably used when type == "http_basic" — confirm with the consumer.
    username: Optional[str] = None
    password: Optional[str] = None
class TracingConfig(BaseModel):
    # Tracing exporter selection; defaults to the console exporter.
    exporter: Literal["otel_http", "console"] = "console"
    # Exporter endpoint; presumably only meaningful for "otel_http" — confirm.
    endpoint: Optional[str] = None
class MetricsConfig(BaseModel):
    # Metrics exporter selection; defaults to the console exporter.
    exporter: Literal["otel_http", "prometheus", "console"] = "console"
    # Exporter endpoint; presumably only meaningful for "otel_http" — confirm.
    endpoint: Optional[str] = None
class AppConfig(BaseModel):
    # General application settings; every field is optional with a default.
    name: Optional[str] = "LLM Guard API"
    port: Optional[int] = 7860
    log_level: Optional[str] = "INFO"
    # Name suggests scanning stops at the first failing scanner — confirm.
    scan_fail_fast: Optional[bool] = False
    # Time limits for prompt / output scanning; units not shown here
    # (presumably seconds — confirm with the caller).
    scan_prompt_timeout: Optional[int] = 10
    scan_output_timeout: Optional[int] = 30
class ScannerConfig(BaseModel):
    # One scanner entry: a required scanner `type` name plus a free-form
    # `params` mapping (a fresh empty dict when omitted).
    type: str = Field(...)
    params: Optional[Dict] = Field(default_factory=dict)
class Config(BaseModel):
    # Top-level configuration built by get_config from the YAML file.
    # Only the two scanner lists are required.
    input_scanners: List[ScannerConfig]
    output_scanners: List[ScannerConfig]
    # Sections that are always present: omitted ones are built from defaults.
    rate_limit: RateLimitConfig = Field(default_factory=RateLimitConfig)
    cache: CacheConfig = Field(default_factory=CacheConfig)
    # Optional sections: None when absent from the file.
    auth: Optional[AuthConfig] = None
    app: AppConfig = Field(default_factory=AppConfig)
    tracing: Optional[TracingConfig] = None
    metrics: Optional[MetricsConfig] = None
def _path_constructor(_loader: Any, node: Any) -> str:
    """Expand `${VAR}` / `${VAR:default}` placeholders in a YAML scalar.

    Each placeholder is replaced by the value of environment variable ``VAR``.
    If ``VAR`` is unset, it is replaced by the text after the first ``:``, or by
    ``""`` when no default is given. Everything after the first ``:`` is kept,
    so defaults may themselves contain colons (e.g. ``${URL:http://host:8000}``).

    Args:
        _loader: The YAML loader invoking this constructor (unused).
        node: The scalar node; its ``value`` holds the raw string.

    Returns:
        The scalar string with every placeholder substituted.
    """

    def replace_fn(match):
        # partition (not split) keeps colons inside the default value intact.
        name, _, default = match.group(1).partition(":")
        return os.environ.get(name, default)

    return _var_matcher.sub(replace_fn, node.value)
def load_yaml(filename: str) -> dict:
    """Parse a YAML file, expanding `${VAR}` / `${VAR:default}` placeholders.

    On first use this registers an ``!envvar`` implicit resolver (pattern
    ``_tag_matcher``) and constructor (``_path_constructor``) on the global
    ``yaml.SafeLoader``. The registration therefore also applies to later
    ``yaml.safe_load`` calls elsewhere in the process. Subsequent calls skip it.

    Args:
        filename: Path of the YAML file to read.

    Returns:
        The document's top-level mapping. An empty dict is returned when the
        file cannot be read or parsed, or when its top level is not a mapping
        (both cases are logged), and also for an empty document.
    """
    # add_implicit_resolver appends to class-level resolver lists on every call,
    # so register only once. The constructor entry serves as the marker.
    if "!envvar" not in yaml.SafeLoader.yaml_constructors:
        yaml.add_implicit_resolver("!envvar", _tag_matcher, None, yaml.SafeLoader)
        yaml.add_constructor("!envvar", _path_constructor, yaml.SafeLoader)

    try:
        # Binary mode lets PyYAML detect the encoding (UTF-8/UTF-16 via BOM)
        # instead of relying on the locale. Decode failures raise yaml.YAMLError.
        with open(filename, "rb") as f:
            data = yaml.safe_load(f)
    except (OSError, yaml.YAMLError) as exc:
        # OSError covers FileNotFoundError, PermissionError, IsADirectoryError, ...
        LOGGER.error("Error loading YAML file", exception=exc)
        return dict()

    if data is None:
        # Empty (or comment-only) document: safe_load returns None.
        return dict()
    if not isinstance(data, dict):
        LOGGER.error(
            "YAML file does not contain a mapping",
            file_name=filename,
            found_type=type(data).__name__,
        )
        return dict()
    return data
def get_config(file_name: str) -> Optional[Config]:
    """Load the YAML config file and validate it into a ``Config``.

    Args:
        file_name: Path of the YAML config file.

    Returns:
        The validated ``Config``. Returns None when ``load_yaml`` yields nothing
        usable: an unreadable/unparsable file, an empty document, or a top level
        that is not a mapping.

    Raises:
        pydantic's validation error when the mapping does not match ``Config``
        (raised by the ``Config(**conf)`` call).
    """
    LOGGER.debug("Loading config file", file_name=file_name)
    conf = load_yaml(file_name)
    # `not conf` also covers None (empty YAML document), which `== {}` missed
    # and which would otherwise crash in `Config(**None)`.
    if not conf:
        return None
    if not isinstance(conf, dict):
        LOGGER.error("Config file must contain a YAML mapping", file_name=file_name)
        return None
    return Config(**conf)