import copy
from argparse import ArgumentError, ArgumentParser
from collections.abc import Awaitable
from contextvars import ContextVar
from dataclasses import dataclass, field
from io import BytesIO
from pathlib import Path
from typing import (
    IO,
    Any,
    Callable,
    Literal,
    Optional,
    TypeVar,
    Union,
    cast,
)

from pil_utils import BuildImage
from pydantic import BaseModel, ValidationError

from .exception import (
    ArgModelMismatch,
    ArgParserExit,
    ImageNumberMismatch,
    OpenImageFailed,
    ParserExit,
    TextNumberMismatch,
    TextOrNameNotEnough,
)
from .utils import is_coroutine_callable, random_image, random_text, run_sync


class UserInfo(BaseModel):
    """Name and gender of one user, as carried in ``MemeArgsModel.user_infos``."""

    name: str = ""
    gender: Literal["male", "female", "unknown"] = "unknown"


class MemeArgsModel(BaseModel):
    """Base pydantic model for meme arguments.

    ``MemeArgsType.model`` is this class or a subclass of it.
    """

    # pydantic copies mutable field defaults per instance, so ``[]`` is safe here.
    user_infos: list[UserInfo] = []


ArgsModel = TypeVar("ArgsModel", bound=MemeArgsModel)

# A meme implementation. ``Meme.__call__`` invokes it as
# ``function(images=..., texts=..., args=...)``. It returns the rendered image,
# either directly or through an awaitable.
MemeFunction = Union[
    Callable[[list[BuildImage], list[str], ArgsModel], BytesIO],
    Callable[[list[BuildImage], list[str], ArgsModel], Awaitable[BytesIO]],
]


# Capture buffer for argparse output (usage/help/error text).
# ``Meme.parse_args`` sets it to "" for the duration of a parse. While it is
# set, ``MemeArgsParser`` appends its messages here instead of writing them to
# a stream.
parser_message: ContextVar[str] = ContextVar("parser_message")


class MemeArgsParser(ArgumentParser):
    """`shell_like` command argument parser that does not exit the program
    when parsing fails.

    Usage:
        Same as `argparse.ArgumentParser`; see the documentation:
        [argparse](https://docs.python.org/3/library/argparse.html)
    """

    def _print_message(self, message: str, file: Optional[IO[str]] = None):
        # Append to the active capture buffer if there is one; otherwise
        # fall back to argparse's normal stream output.
        if (msg := parser_message.get(None)) is not None:
            parser_message.set(msg + message)
        else:
            super()._print_message(message, file)

    def exit(self, status: int = 0, message: Optional[str] = None):
        """Raise ``ParserExit`` instead of terminating the interpreter.

        The exception carries the captured parser output, or ``None`` when no
        capture is active.
        """
        if message:
            self._print_message(message)
        raise ParserExit(status=status, error_message=parser_message.get(None))


@dataclass
class MemeArgsType:
    """Argument specification (command-line parser + validation model) of a meme."""

    # Command-line parser; ``Meme.parse_args`` deep-copies it before adding
    # its own positional argument.
    parser: MemeArgsParser
    # Model used by ``Meme.__call__`` to validate the ``args`` mapping.
    model: type[MemeArgsModel]
    # Not referenced in the code shown here (presumably sample argument
    # sets, e.g. for previews) -- TODO confirm.
    instances: list[MemeArgsModel] = field(default_factory=list)


@dataclass
class MemeParamsType:
    """Accepted input counts and defaults for a meme."""

    # Inclusive bounds on the number of images ``Meme.__call__`` accepts.
    min_images: int = 0
    max_images: int = 0
    # Inclusive bounds on the number of texts ``Meme.__call__`` accepts.
    min_texts: int = 0
    max_texts: int = 0
    # Texts used by ``Meme.generate_preview`` when their count is within bounds.
    default_texts: list[str] = field(default_factory=list)
    # Optional parser/model; ``MemeArgsModel`` is used when this is ``None``.
    args_type: Optional[MemeArgsType] = None


@dataclass
class Meme:
    """A meme generator function together with its input constraints."""

    key: str
    function: MemeFunction
    params_type: MemeParamsType
    keywords: list[str] = field(default_factory=list)
    patterns: list[str] = field(default_factory=list)

    async def __call__(
        self,
        *,
        images: Optional[
            Union[list[str], list[Path], list[bytes], list[BytesIO]]
        ] = None,
        texts: Optional[list[str]] = None,
        args: Optional[dict[str, Any]] = None,
    ) -> BytesIO:
        """Validate the inputs and render the meme.

        Args:
            images: Image sources (``str``/``Path`` values, raw ``bytes`` or
                ``BytesIO`` streams). Defaults to no images.
            texts: Texts to render. Defaults to no texts. The meme function
                receives a copy, so the caller's list is never modified.
            args: Mapping validated against the meme's argument model.
                Defaults to an empty mapping.

        Returns:
            The image produced by the meme function.

        Raises:
            ImageNumberMismatch: ``len(images)`` is outside the allowed range.
            TextNumberMismatch: ``len(texts)`` is outside the allowed range.
            ArgModelMismatch: ``args`` fails model validation.
            OpenImageFailed: an image could not be opened.
        """
        # ``None`` defaults instead of shared ``[]``/``{}`` literals. The texts
        # list is handed to user code, which could otherwise mutate a default
        # that is shared by every later call (or the caller's own list).
        images = [] if images is None else images
        texts = [] if texts is None else list(texts)
        args = {} if args is None else args

        params = self.params_type
        if not (params.min_images <= len(images) <= params.max_images):
            raise ImageNumberMismatch(self.key, params.min_images, params.max_images)
        if not (params.min_texts <= len(texts) <= params.max_texts):
            raise TextNumberMismatch(self.key, params.min_texts, params.max_texts)

        args_model = (
            params.args_type.model if params.args_type is not None else MemeArgsModel
        )
        try:
            model = args_model.parse_obj(args)
        except ValidationError as e:
            raise ArgModelMismatch(self.key, str(e)) from e

        imgs: list[BuildImage] = []
        try:
            for image in images:
                if isinstance(image, bytes):
                    # Raw bytes are wrapped so they can be opened as a stream.
                    image = BytesIO(image)
                imgs.append(BuildImage.open(image))  # type: ignore
        except Exception as e:
            raise OpenImageFailed(str(e)) from e

        values = {"images": imgs, "texts": texts, "args": model}
        if is_coroutine_callable(self.function):
            return await cast(Callable[..., Awaitable[BytesIO]], self.function)(
                **values
            )
        # Synchronous meme functions are wrapped with ``run_sync`` so they can
        # be awaited like the asynchronous ones.
        return await run_sync(cast(Callable[..., BytesIO], self.function))(**values)

    def parse_args(self, args: Optional[list[str]] = None) -> dict[str, Any]:
        """Parse shell-like command arguments for this meme.

        Args:
            args: Argument tokens. Defaults to no arguments; ``sys.argv`` is
                never consulted.

        Returns:
            The parsed namespace as a dict. Trailing positional tokens are
            collected under ``"texts"``.

        Raises:
            ArgParserExit: parsing failed, or the parser requested an exit
                (e.g. ``--help``). Any captured parser output is attached.
        """
        # argparse reads ``sys.argv`` when given ``None``, so normalize here
        # instead of forwarding ``None``.
        argv = [] if args is None else args
        # Deep-copy so that adding "texts" does not mutate the shared parser.
        parser = (
            copy.deepcopy(self.params_type.args_type.parser)
            if self.params_type.args_type
            else MemeArgsParser()
        )
        parser.add_argument("texts", nargs="*", default=[])
        # Activate message capture for MemeArgsParser during this parse only.
        token = parser_message.set("")
        try:
            return vars(parser.parse_args(argv))
        except ArgumentError as e:
            raise ArgParserExit(self.key, str(e)) from e
        except ParserExit as e:
            raise ArgParserExit(self.key, e.error_message) from e
        finally:
            parser_message.reset(token)

    async def generate_preview(
        self, *, args: Optional[dict[str, Any]] = None
    ) -> BytesIO:
        """Render a preview using random images and default or random texts.

        ``min_images`` random images are used. ``default_texts`` are used when
        their count is within bounds; otherwise ``min_texts`` random texts are
        used. Each time the meme raises ``TextOrNameNotEnough``, one more random
        text is appended and rendering is retried. Once the count exceeds
        ``max_texts``, ``__call__`` raises ``TextNumberMismatch`` and the loop
        ends.

        Args:
            args: Meme arguments, forwarded to ``__call__``. Defaults to none.
        """
        params = self.params_type
        images = [random_image() for _ in range(params.min_images)]
        if params.min_texts <= len(params.default_texts) <= params.max_texts:
            texts = params.default_texts.copy()
        else:
            texts = [random_text() for _ in range(params.min_texts)]

        while True:
            try:
                return await self.__call__(images=images, texts=texts, args=args)
            except TextOrNameNotEnough:
                texts.append(random_text())