L2_Agent / lagent /actions /base_action.py
Pluto0616's picture
L2_agent
499a238
import inspect
import logging
import re
from abc import ABCMeta
from copy import deepcopy
from functools import wraps
from typing import Callable, Optional, Type, get_args, get_origin
try:
from typing import Annotated
except ImportError:
from typing_extensions import Annotated
from griffe import Docstring
try:
from griffe import DocstringSectionKind
except ImportError:
from griffe.enumerations import DocstringSectionKind
from ..schema import ActionReturn, ActionStatusCode
from .parser import BaseParser, JsonParser, ParseError
logging.getLogger('griffe').setLevel(logging.ERROR)
def tool_api(func: Optional[Callable] = None,
             *,
             explode_return: bool = False,
             returns_named_value: bool = False,
             **kwargs):
    """Turn functions into tools. It will parse typehints as well as docstrings
    to build the tool description and attach it to functions via an attribute
    ``api_description``.

    The decorator can be applied bare (``@tool_api``) or with options
    (``@tool_api(explode_return=True)``). The returned wrapper forwards all
    positional and keyword arguments to the original callable unchanged, so
    it works for methods as well as plain functions.

    Examples:

        .. code-block:: python

            # typehints has higher priority than docstrings
            from typing import Annotated

            @tool_api
            def add(a: Annotated[int, 'augend'], b: Annotated[int, 'addend'] = 1):
                '''Add operation

                Args:
                    a (int): a
                    b (int): b
                '''
                return a + b

            print(add.api_description)

    Args:
        func (Optional[Callable]): function to decorate. Defaults to ``None``.
        explode_return (bool): whether to flatten the dictionary or tuple return
            as the ``return_data`` field. When enabled, it is recommended to
            annotate the member in docstrings. Defaults to ``False``.

            .. code-block:: python

                @tool_api(explode_return=True)
                def foo(a, b):
                    '''A simple function

                    Args:
                        a (int): a
                        b (int): b

                    Returns:
                        dict: information of inputs
                            * x: value of a
                            * y: value of b
                    '''
                    return {'x': a, 'y': b}

                print(foo.api_description)

        returns_named_value (bool): whether to parse ``thing: Description`` in
            returns sections as a name and description, rather than a type and
            description. When true, type must be wrapped in parentheses:
            ``(int): Description``. When false, parentheses are optional but
            the items cannot be named: ``int: Description``. Defaults to ``False``.
        **kwargs: extra options forwarded to the Google-style docstring parser.

    Returns:
        Callable: wrapped function or partial decorator

    Important:
        ``return_data`` field will be added to ``api_description`` only
        when ``explode_return`` or ``returns_named_value`` is enabled.
    """

    def _detect_type(string):
        """Map a textual type name onto one of the schema type labels."""
        field_type = 'STRING'
        if 'list' in string:
            field_type = 'Array'
        elif 'str' not in string:
            if 'float' in string:
                field_type = 'FLOAT'
            elif 'int' in string:
                field_type = 'NUMBER'
            elif 'bool' in string:
                field_type = 'BOOLEAN'
        return field_type

    def _attach_type(entry):
        """Replace the ``annotation`` key of ``entry`` with a ``type`` label.

        Entries documented without a type simply get no ``type`` key, so
        consumers fall back to their own default instead of crashing here.
        """
        annotation = entry.pop('annotation', None)
        if annotation:
            entry['type'] = _detect_type(str(annotation).lower())
        return entry

    def _explode(desc):
        """Parse the bullet list under a return description into fields."""
        kvs = []
        # Re-shape the bullet lines into a fake ``Args:`` section so that the
        # Google parser can split each one into name / annotation / description.
        desc = '\nArgs:\n' + '\n'.join([
            ' ' + item.lstrip(' -+*#.')
            for item in desc.split('\n')[1:] if item.strip()
        ])
        docs = Docstring(desc).parse('google')
        if not docs:
            return kvs
        if docs[0].kind is DocstringSectionKind.parameters:
            for d in docs[0].value:
                kvs.append(_attach_type(d.as_dict()))
        return kvs

    def _parse_tool(function):
        """Build the ``api_description`` dict for ``function``."""
        # Remove rst role syntax (``:class:`Foo``` -> ``Foo``). The role name is
        # restricted to identifier-like characters so that the match cannot
        # start at an unrelated colon such as the one in ``name (type): ...``.
        docs = Docstring(
            re.sub(r':(?:[\w.+-]+:)+`(.+?)`', r'\1',
                   function.__doc__ or '')).parse(
                       'google',
                       returns_named_value=returns_named_value,
                       **kwargs)
        has_summary = bool(docs) and docs[0].kind is DocstringSectionKind.text
        desc = dict(
            name=function.__name__,
            description=docs[0].value if has_summary else '',
            parameters=[],
            required=[],
        )
        args_doc, returns_doc = {}, []
        for doc in docs:
            if doc.kind is DocstringSectionKind.parameters:
                for d in doc.value:
                    d = _attach_type(d.as_dict())
                    args_doc[d['name']] = d
            if doc.kind is DocstringSectionKind.returns:
                for d in doc.value:
                    d = d.as_dict()
                    if not d['name']:
                        d.pop('name')
                    returns_doc.append(_attach_type(d))

        sig = inspect.signature(function)
        for name, param in sig.parameters.items():
            if name == 'self':
                continue
            documented = args_doc.get(param.name, {})
            parameter = dict(
                name=param.name,
                type='STRING',
                description=documented.get('description', ''))
            annotation = param.annotation
            if annotation is inspect.Signature.empty:
                # No typehint: fall back to the type written in the docstring.
                parameter['type'] = documented.get('type', 'STRING')
            else:
                if get_origin(annotation) is Annotated:
                    # ``Annotated`` may carry several metadata items; prefer
                    # the first string, otherwise use the first item.
                    annotation, *metadata = get_args(annotation)
                    info = next(
                        (m for m in metadata if isinstance(m, str)),
                        metadata[0])
                    if info:
                        parameter['description'] = info
                # NOTE(review): ``get_args`` returns a tuple, so this loop runs
                # at most once and the type is detected from the string form
                # of the type arguments (``List[int]`` -> ``NUMBER``). Kept as
                # is to avoid changing existing descriptions - confirm intent.
                while get_origin(annotation):
                    annotation = get_args(annotation)
                parameter['type'] = _detect_type(str(annotation))
            desc['parameters'].append(parameter)
            if param.default is inspect.Signature.empty:
                desc['required'].append(param.name)

        return_data = []
        if explode_return:
            # A missing ``Returns:`` section yields no ``return_data`` rather
            # than failing at decoration time.
            if returns_doc:
                return_data = _explode(returns_doc[0]['description'])
        elif returns_named_value:
            return_data = returns_doc
        if return_data:
            desc['return_data'] = return_data
        return desc

    def decorate(function):
        """Wrap ``function`` and attach its parsed ``api_description``."""
        if inspect.iscoroutinefunction(function):

            @wraps(function)
            async def wrapper(*args, **kwargs):
                return await function(*args, **kwargs)

        else:

            @wraps(function)
            def wrapper(*args, **kwargs):
                return function(*args, **kwargs)

        wrapper.api_description = _parse_tool(function)
        return wrapper

    if callable(func):
        # Bare ``@tool_api`` usage: decorate immediately.
        return decorate(func)
    return decorate


class ToolMeta(ABCMeta):
    """Metaclass of tools.

    While a class is being created, it collects the ``api_description`` of
    every attribute decorated with ``tool_api`` and stores the aggregated
    result on the class as ``__tool_description__``. It also sets
    ``_is_toolkit``, which is ``True`` only when the class body defines
    decorated APIs other than ``run``.

    Raises:
        KeyError: if the class body defines a decorated ``run`` together with
            other decorated tool APIs.
    """

    def __new__(mcs, name, base, attrs):
        # Only a leading free-text section is a usable description; a missing
        # or ``None`` docstring, or one that starts with e.g. ``Args:``,
        # yields an empty description.
        sections = Docstring(attrs.get('__doc__') or '').parse('google')
        summary = ''
        if sections and sections[0].kind is DocstringSectionKind.text:
            summary = sections[0].value
        is_toolkit, tool_desc = True, dict(name=name, description=summary)

        def _adopt_run_description(api_desc):
            """Copy the description of ``run`` into the tool description."""
            tool_desc['parameters'] = api_desc['parameters']
            tool_desc['required'] = api_desc['required']
            if api_desc['description']:
                tool_desc['description'] = api_desc['description']
            if api_desc.get('return_data'):
                tool_desc['return_data'] = api_desc['return_data']

        for key, value in attrs.items():
            if callable(value) and hasattr(value, 'api_description'):
                api_desc = getattr(value, 'api_description')
                if key == 'run':
                    _adopt_run_description(api_desc)
                    is_toolkit = False
                else:
                    tool_desc.setdefault('api_list', []).append(api_desc)

        if not is_toolkit and 'api_list' in tool_desc:
            raise KeyError('`run` and other tool APIs can not be implemented '
                           'at the same time')

        if is_toolkit and 'api_list' not in tool_desc:
            # No decorated API at all: treat the class as a single tool and
            # decorate an undecorated ``run`` on the author's behalf.
            is_toolkit = False
            if callable(attrs.get('run')):
                run_api = tool_api(attrs['run'])
                _adopt_run_description(run_api.api_description)
                attrs['run'] = run_api
            else:
                tool_desc['parameters'], tool_desc['required'] = [], []

        attrs['_is_toolkit'] = is_toolkit
        attrs['__tool_description__'] = tool_desc
        return super().__new__(mcs, name, base, attrs)


class BaseAction(metaclass=ToolMeta):
    """Base class for all actions.

    Args:
        description (:class:`Optional[dict]`): The description of the action.
            Defaults to ``None``.
        parser (:class:`Type[BaseParser]`): The parser class to process the
            action's inputs and outputs. Defaults to :class:`JsonParser`.

    Examples:

        * simple tool

        .. code-block:: python

            class Bold(BaseAction):
                '''Make text bold'''

                def run(self, text: str):
                    '''
                    Args:
                        text (str): input text

                    Returns:
                        str: bold text
                    '''
                    return '**' + text + '**'

            action = Bold()

        * toolkit with multiple APIs

        .. code-block:: python

            class Calculator(BaseAction):
                '''Calculator'''

                @tool_api
                def add(self, a, b):
                    '''Add operation

                    Args:
                        a (int): augend
                        b (int): addend

                    Returns:
                        int: sum
                    '''
                    return a + b

                @tool_api
                def sub(self, a, b):
                    '''Subtraction operation

                    Args:
                        a (int): minuend
                        b (int): subtrahend

                    Returns:
                        int: difference
                    '''
                    return a - b

            action = Calculator()
    """

    def __init__(
        self,
        description: Optional[dict] = None,
        parser: Type[BaseParser] = JsonParser,
    ):
        # Work on a private copy so per-instance tweaks never leak into the
        # description stored on the class.
        template = description if description else self.__tool_description__
        self._description = deepcopy(template)
        self._name = self._description['name']
        self._parser = parser(self)

    def __call__(self, inputs: str, name='run') -> ActionReturn:
        """Parse ``inputs``, invoke the API called ``name`` and wrap the result.

        Args:
            inputs (str): raw arguments, handed to the parser.
            name (str): attribute name of the API to call. Defaults to
                ``'run'``.

        Returns:
            ActionReturn: the outcome of the call; failures are reported
            through ``errmsg`` and ``state`` instead of being raised.
        """
        raw_request = {'inputs': inputs, 'name': name}

        # Unknown API name.
        if not hasattr(self, name):
            return ActionReturn(
                raw_request,
                type=self.name,
                errmsg=f'invalid API: {name}',
                state=ActionStatusCode.API_ERROR)

        # Arguments that the parser rejects.
        try:
            parsed = self._parser.parse_inputs(inputs, name)
        except ParseError as exc:
            return ActionReturn(
                raw_request,
                type=self.name,
                errmsg=exc.err_msg,
                state=ActionStatusCode.ARGS_ERROR)

        # Any error raised by the API itself is reported, not propagated.
        try:
            produced = getattr(self, name)(**parsed)
        except Exception as exc:
            return ActionReturn(
                parsed,
                type=self.name,
                errmsg=str(exc),
                state=ActionStatusCode.API_ERROR)

        if not isinstance(produced, ActionReturn):
            result = self._parser.parse_outputs(produced)
            return ActionReturn(parsed, type=self.name, result=result)

        # The API built its own return object: only fill in what it left empty.
        if not produced.args:
            produced.args = parsed
        if not produced.type:
            produced.type = self.name
        return produced

    @property
    def name(self):
        """Name of the tool, read from its description."""
        return self._name

    @property
    def is_toolkit(self):
        """Value of the ``_is_toolkit`` flag set by the metaclass."""
        return self._is_toolkit

    @property
    def description(self) -> dict:
        """Description of the tool."""
        return self._description

    def __repr__(self):
        return '{}'.format(self.description)

    __str__ = __repr__


class AsyncActionMixin:
    """Mixin whose ``__call__`` awaits the selected API.

    The host class is expected to provide the ``name`` attribute and the
    ``_parser`` object used below.
    """

    async def __call__(self, inputs: str, name='run') -> ActionReturn:
        """Parse ``inputs``, await the API called ``name`` and wrap the result.

        Args:
            inputs (str): raw arguments, handed to the parser.
            name (str): attribute name of the API to call. Defaults to
                ``'run'``.

        Returns:
            ActionReturn: the outcome of the call; failures are reported
            through ``errmsg`` and ``state`` instead of being raised.
        """
        raw_request = {'inputs': inputs, 'name': name}

        # Unknown API name.
        if not hasattr(self, name):
            return ActionReturn(
                raw_request,
                type=self.name,
                errmsg=f'invalid API: {name}',
                state=ActionStatusCode.API_ERROR)

        # Arguments that the parser rejects.
        try:
            parsed = self._parser.parse_inputs(inputs, name)
        except ParseError as exc:
            return ActionReturn(
                raw_request,
                type=self.name,
                errmsg=exc.err_msg,
                state=ActionStatusCode.ARGS_ERROR)

        # Any error raised while calling or awaiting the API is reported,
        # not propagated.
        try:
            produced = await getattr(self, name)(**parsed)
        except Exception as exc:
            return ActionReturn(
                parsed,
                type=self.name,
                errmsg=str(exc),
                state=ActionStatusCode.API_ERROR)

        if not isinstance(produced, ActionReturn):
            result = self._parser.parse_outputs(produced)
            return ActionReturn(parsed, type=self.name, result=result)

        # The API built its own return object: only fill in what it left empty.
        if not produced.args:
            produced.args = parsed
        if not produced.type:
            produced.type = self.name
        return produced