|
import inspect |
|
import logging |
|
import re |
|
from abc import ABCMeta |
|
from copy import deepcopy |
|
from functools import wraps |
|
from typing import Callable, Optional, Type, get_args, get_origin |
|
|
|
try: |
|
from typing import Annotated |
|
except ImportError: |
|
from typing_extensions import Annotated |
|
|
|
from griffe import Docstring |
|
|
|
try: |
|
from griffe import DocstringSectionKind |
|
except ImportError: |
|
from griffe.enumerations import DocstringSectionKind |
|
|
|
from ..schema import ActionReturn, ActionStatusCode |
|
from .parser import BaseParser, JsonParser, ParseError |
|
|
|
# Raise griffe's log threshold to ERROR so the warnings it emits while the
# docstrings of tools are parsed below are not printed.
logging.getLogger(name='griffe').setLevel(level=logging.ERROR)
|
|
|
|
|
def tool_api(func: Optional[Callable] = None,
             *,
             explode_return: bool = False,
             returns_named_value: bool = False,
             **kwargs):
    """Turn functions into tools. It parses type hints as well as docstrings
    to build the tool description and attaches it to the returned function
    via an attribute ``api_description``.

    Examples:

        .. code-block:: python

            # type hints have higher priority than docstrings
            from typing import Annotated

            @tool_api
            def add(a: Annotated[int, 'augend'], b: Annotated[int, 'addend'] = 1):
                '''Add operation

                Args:
                    x (int): a
                    y (int): b
                '''
                return a + b

            print(add.api_description)

    Args:
        func (Optional[Callable]): function to decorate. Defaults to ``None``,
            in which case a decorator is returned.
        explode_return (bool): whether to flatten the dictionary or tuple return
            as the ``return_data`` field. When enabled, it is recommended to
            annotate the members in docstrings. Defaults to ``False``.

            .. code-block:: python

                @tool_api(explode_return=True)
                def foo(a, b):
                    '''A simple function

                    Args:
                        a (int): a
                        b (int): b

                    Returns:
                        dict: information of inputs
                            * x: value of a
                            * y: value of b
                    '''
                    return {'x': a, 'y': b}

                print(foo.api_description)

        returns_named_value (bool): whether to parse ``thing: Description`` in
            returns sections as a name and description, rather than a type and
            description. When true, type must be wrapped in parentheses:
            ``(int): Description``. When false, parentheses are optional but
            the items cannot be named: ``int: Description``. Defaults to ``False``.
        **kwargs: extra options forwarded to griffe's Google-style docstring
            parser.

    Returns:
        Callable: wrapped function or partial decorator

    Important:
        ``return_data`` field will be added to ``api_description`` only
        when ``explode_return`` or ``returns_named_value`` is enabled.
    """

    def _detect_type(string):
        """Map a type string to a tool field type by substring matching.

        ``'list'`` gives ``'Array'``; otherwise, unless ``'str'`` occurs,
        ``'float'``/``'int'``/``'bool'`` give ``'FLOAT'``/``'NUMBER'``/
        ``'BOOLEAN'``. Anything else is ``'STRING'``.
        """
        field_type = 'STRING'
        if 'list' in string:
            field_type = 'Array'
        elif 'str' not in string:
            if 'float' in string:
                field_type = 'FLOAT'
            elif 'int' in string:
                field_type = 'NUMBER'
            elif 'bool' in string:
                field_type = 'BOOLEAN'
        return field_type

    def _set_type(item):
        """Replace the ``annotation`` key of a parsed docstring item with
        a ``type`` key, in place.

        Items without an annotation (e.g. ``name: description`` with no
        type) get no ``type`` key, so callers fall back to their default.
        """
        annotation = item.pop('annotation', None)
        if annotation:
            # str() also covers annotation objects that are not plain strings.
            item['type'] = _detect_type(str(annotation).lower())
        return item

    def _annotation_type(annotation):
        """Map a signature annotation to a tool field type."""
        origin = get_origin(annotation)
        if origin is list:
            # List[int] / list[int]: unwrapping to the item type below would
            # lose the container, so report it as an array directly.
            return 'Array'
        if origin:
            # Unwrap generics such as Optional[int] into their arguments.
            annotation = get_args(annotation)
        return _detect_type(str(annotation))

    def _explode(desc):
        """Parse the lines after the first one of a return description as
        ``name (type): description`` fields."""
        # Re-wrap the bullet lines as a Google "Args:" section so griffe
        # can split them into named items.
        lines = [
            ' ' + item.lstrip(' -+*#.')
            for item in desc.split('\n')[1:] if item.strip()
        ]
        docs = Docstring('\nArgs:\n' + '\n'.join(lines)).parse('google')
        if not docs or docs[0].kind is not DocstringSectionKind.parameters:
            return []
        return [_set_type(d.as_dict()) for d in docs[0].value]

    def _parse_tool(function):
        """Build the ``api_description`` dict for ``function``."""
        # Reduce roles such as :class:`Foo` to ``Foo`` before parsing.
        docs = Docstring(
            re.sub(r':(.+?):`(.+?)`', r'\2', function.__doc__ or '')).parse(
                'google', returns_named_value=returns_named_value, **kwargs)
        desc = dict(
            name=function.__name__,
            description=docs[0].value if docs
            and docs[0].kind is DocstringSectionKind.text else '',
            parameters=[],
            required=[],
        )
        args_doc, returns_doc = {}, []
        for doc in docs:
            if doc.kind is DocstringSectionKind.parameters:
                for d in doc.value:
                    d = _set_type(d.as_dict())
                    args_doc[d['name']] = d
            elif doc.kind is DocstringSectionKind.returns:
                for d in doc.value:
                    d = d.as_dict()
                    if not d['name']:
                        d.pop('name')
                    returns_doc.append(_set_type(d))

        for param in inspect.signature(function).parameters.values():
            if param.name == 'self':
                continue
            doc_entry = args_doc.get(param.name, {})
            parameter = dict(
                name=param.name,
                type='STRING',
                description=doc_entry.get('description', ''))
            annotation = param.annotation
            if annotation is inspect.Signature.empty:
                parameter['type'] = doc_entry.get('type', 'STRING')
            else:
                if get_origin(annotation) is Annotated:
                    # Annotated[T, info, ...]: the first metadata item is the
                    # description; any further items are ignored.
                    annotation, info, *_ = get_args(annotation)
                    if info:
                        parameter['description'] = info
                parameter['type'] = _annotation_type(annotation)
            desc['parameters'].append(parameter)
            if param.default is inspect.Signature.empty:
                desc['required'].append(param.name)

        return_data = []
        if explode_return:
            # Without a Returns section there is nothing to explode.
            if returns_doc:
                return_data = _explode(returns_doc[0]['description'])
        elif returns_named_value:
            return_data = returns_doc
        if return_data:
            desc['return_data'] = return_data
        return desc

    def decorate(function):
        """Wrap ``function`` and attach its ``api_description``.

        The attribute is set on a thin wrapper, leaving ``function`` itself
        unmodified. Coroutine functions get an ``async`` wrapper.
        """
        if inspect.iscoroutinefunction(function):

            @wraps(function)
            async def wrapper(*args, **call_kwargs):
                return await function(*args, **call_kwargs)

        else:

            @wraps(function)
            def wrapper(*args, **call_kwargs):
                return function(*args, **call_kwargs)

        wrapper.api_description = _parse_tool(function)
        return wrapper

    return decorate(func) if callable(func) else decorate
|
|
|
|
|
def _apply_run_description(tool_desc, api_desc):
    """Copy the ``run`` API description into ``tool_desc`` in place.

    ``parameters`` and ``required`` are always copied. ``description`` and
    ``return_data`` are copied only when non-empty, so the class docstring
    summary remains the fallback description.
    """
    tool_desc['parameters'] = api_desc['parameters']
    tool_desc['required'] = api_desc['required']
    if api_desc['description']:
        tool_desc['description'] = api_desc['description']
    if api_desc.get('return_data'):
        tool_desc['return_data'] = api_desc['return_data']


class ToolMeta(ABCMeta):
    """Metaclass of tools.

    At class-creation time it stores a ``__tool_description__`` dict built
    from the class docstring and the ``api_description`` of its tool APIs,
    and sets ``_is_toolkit``:

    * a class whose ``@tool_api`` methods are not ``run`` is a toolkit and
      its APIs are collected under ``api_list``;
    * otherwise it is a single tool whose ``run`` (wrapped with
      :func:`tool_api` if not already decorated) supplies the parameters.

    Raises:
        KeyError: if a decorated ``run`` coexists with other tool APIs.
    """

    def __new__(mcs, name, base, attrs):
        # ``__doc__`` may be absent or explicitly None; griffe needs a str.
        class_docs = Docstring(attrs.get('__doc__') or '').parse('google')
        summary = ''
        # Only a leading free-text section is a description; a docstring
        # starting with e.g. "Args:" would otherwise yield a list here.
        if class_docs and class_docs[0].kind is DocstringSectionKind.text:
            summary = class_docs[0].value
        is_toolkit, tool_desc = True, dict(name=name, description=summary)
        for key, value in attrs.items():
            if callable(value) and hasattr(value, 'api_description'):
                api_desc = getattr(value, 'api_description')
                if key == 'run':
                    _apply_run_description(tool_desc, api_desc)
                    is_toolkit = False
                else:
                    tool_desc.setdefault('api_list', []).append(api_desc)
        if not is_toolkit and 'api_list' in tool_desc:
            raise KeyError('`run` and other tool APIs can not be implemented '
                           'at the same time')
        if is_toolkit and 'api_list' not in tool_desc:
            # No decorated APIs at all: treat the class as a single tool.
            is_toolkit = False
            if callable(attrs.get('run')):
                run_api = tool_api(attrs['run'])
                _apply_run_description(tool_desc, run_api.api_description)
                attrs['run'] = run_api
            else:
                tool_desc['parameters'], tool_desc['required'] = [], []
        attrs['_is_toolkit'] = is_toolkit
        attrs['__tool_description__'] = tool_desc
        return super().__new__(mcs, name, base, attrs)
|
|
|
|
|
class BaseAction(metaclass=ToolMeta):
    """Base class for all actions.

    Args:
        description (:class:`Optional[dict]`): A ready-made description of
            the action; when omitted, the class-level
            ``__tool_description__`` is copied. Defaults to ``None``.
        parser (:class:`Type[BaseParser]`): The parser class used to process
            the action's inputs and outputs. Defaults to :class:`JsonParser`.

    Examples:

        * simple tool

        .. code-block:: python

            class Bold(BaseAction):
                '''Make text bold'''

                def run(self, text: str):
                    '''
                    Args:
                        text (str): input text

                    Returns:
                        str: bold text
                    '''
                    return '**' + text + '**'

            action = Bold()

        * toolkit with multiple APIs

        .. code-block:: python

            class Calculator(BaseAction):
                '''Calculator'''

                @tool_api
                def add(self, a, b):
                    '''Add operation

                    Args:
                        a (int): augend
                        b (int): addend

                    Returns:
                        int: sum
                    '''
                    return a + b

                @tool_api
                def sub(self, a, b):
                    '''Subtraction operation

                    Args:
                        a (int): minuend
                        b (int): subtrahend

                    Returns:
                        int: difference
                    '''
                    return a - b

            action = Calculator()
    """

    def __init__(
        self,
        description: Optional[dict] = None,
        parser: Type[BaseParser] = JsonParser,
    ):
        # Deep-copy so per-instance edits never leak into the class template.
        template = description or self.__tool_description__
        self._description = deepcopy(template)
        self._name = self._description['name']
        self._parser = parser(self)

    def __call__(self, inputs: str, name='run') -> ActionReturn:
        """Invoke the API ``name`` with raw ``inputs`` and wrap the outcome.

        An unknown API name, a :class:`ParseError` while parsing the inputs,
        and any exception raised by the API itself are reported through the
        returned :class:`ActionReturn` instead of being raised.
        """
        fallback_args = {'inputs': inputs, 'name': name}
        if not hasattr(self, name):
            return ActionReturn(
                fallback_args,
                type=self.name,
                errmsg=f'invalid API: {name}',
                state=ActionStatusCode.API_ERROR)
        try:
            call_args = self._parser.parse_inputs(inputs, name)
        except ParseError as exc:
            return ActionReturn(
                fallback_args,
                type=self.name,
                errmsg=exc.err_msg,
                state=ActionStatusCode.ARGS_ERROR)
        try:
            outputs = getattr(self, name)(**call_args)
        except Exception as exc:
            return ActionReturn(
                call_args,
                type=self.name,
                errmsg=str(exc),
                state=ActionStatusCode.API_ERROR)
        if not isinstance(outputs, ActionReturn):
            # Plain return value: let the parser render it as the result.
            rendered = self._parser.parse_outputs(outputs)
            return ActionReturn(call_args, type=self.name, result=rendered)
        # The API built its own ActionReturn; only fill in missing fields.
        if not outputs.args:
            outputs.args = call_args
        if not outputs.type:
            outputs.type = self.name
        return outputs

    @property
    def name(self):
        """Name of the action, read from its description at construction."""
        return self._name

    @property
    def is_toolkit(self):
        """Whether the class was classified as a multi-API toolkit."""
        return self._is_toolkit

    @property
    def description(self) -> dict:
        """Description of the tool."""
        return self._description

    def __repr__(self):
        return f'{self.description}'

    __str__ = __repr__
|
|
|
|
|
class AsyncActionMixin:
    """Mixin providing an awaitable ``__call__`` for actions.

    It relies on the host class providing ``name`` and ``_parser`` (as
    ``BaseAction`` does) and on the target API being a coroutine function.
    """

    async def __call__(self, inputs: str, name='run') -> ActionReturn:
        """Await the API ``name`` with raw ``inputs`` and wrap the outcome.

        An unknown API name, a :class:`ParseError` while parsing the inputs,
        and any exception raised by the API itself are reported through the
        returned :class:`ActionReturn` instead of being raised.
        """
        fallback_args = {'inputs': inputs, 'name': name}
        if not hasattr(self, name):
            return ActionReturn(
                fallback_args,
                type=self.name,
                errmsg=f'invalid API: {name}',
                state=ActionStatusCode.API_ERROR)
        try:
            call_args = self._parser.parse_inputs(inputs, name)
        except ParseError as exc:
            return ActionReturn(
                fallback_args,
                type=self.name,
                errmsg=exc.err_msg,
                state=ActionStatusCode.ARGS_ERROR)
        try:
            outputs = await getattr(self, name)(**call_args)
        except Exception as exc:
            return ActionReturn(
                call_args,
                type=self.name,
                errmsg=str(exc),
                state=ActionStatusCode.API_ERROR)
        if not isinstance(outputs, ActionReturn):
            # Plain return value: let the parser render it as the result.
            rendered = self._parser.parse_outputs(outputs)
            return ActionReturn(call_args, type=self.name, result=rendered)
        # The API built its own ActionReturn; only fill in missing fields.
        if not outputs.args:
            outputs.args = call_args
        if not outputs.type:
            outputs.type = self.name
        return outputs
|
|