File size: 5,169 Bytes
d8d14f1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
from typing import Any, Optional, Callable
from swarms.tools.json_former import Jsonformer
from swarms.utils.loguru_logger import initialize_logger
from swarms.utils.lazy_loader import lazy_import_decorator

logger = initialize_logger(log_folder="tool_agent")


@lazy_import_decorator
class ToolAgent:
    """
    Represents a tool agent that performs a specific task using a model and tokenizer.

    Each call to ``run`` builds a ``Jsonformer`` for the task prompt, invokes
    it, and returns its output (optionally post-processed by
    ``parsing_function``).

    Args:
        name (str): The name of the tool agent.
        description (str): A description of the tool agent.
        model (Any): The model used by the tool agent. When truthy, it is
            forwarded to Jsonformer together with ``tokenizer``.
        tokenizer (Any): The tokenizer used by the tool agent. Only
            forwarded to Jsonformer when ``model`` is set.
        json_schema (Any): The JSON schema used by the tool agent.
        max_number_tokens (int): Forwarded to Jsonformer as
            ``max_number_tokens``.
        parsing_function (Optional[Callable]): If given, applied to the raw
            Jsonformer output before it is returned from ``run``.
        llm (Any): Alternative to ``model``; forwarded to Jsonformer as
            ``llm``.
        *args: Accepted for backward compatibility; ignored.
        **kwargs: Accepted for backward compatibility; ignored.

    Attributes:
        name (str): The name of the tool agent.
        description (str): A description of the tool agent.
        model (Any): The model used by the tool agent.
        tokenizer (Any): The tokenizer used by the tool agent.
        json_schema (Any): The JSON schema used by the tool agent.
        max_number_tokens (int): Value forwarded to Jsonformer.
        parsing_function (Optional[Callable]): Optional output post-processor.
        llm (Any): The LLM forwarded to Jsonformer.
        toolagent: The Jsonformer built by the most recent ``run`` call
            (only present after ``run`` has been called with a model or llm).

    Methods:
        run: Runs the tool agent for a specific task.

    Raises:
        ValueError: From ``run`` when neither ``model`` nor ``llm`` is set.
        Exception: Any error raised while building or invoking Jsonformer
            is logged and re-raised unchanged.


    Example:
        from transformers import AutoModelForCausalLM, AutoTokenizer
        from swarms import ToolAgent


        model = AutoModelForCausalLM.from_pretrained("databricks/dolly-v2-12b")
        tokenizer = AutoTokenizer.from_pretrained("databricks/dolly-v2-12b")

        json_schema = {
            "type": "object",
            "properties": {
                "name": {"type": "string"},
                "age": {"type": "number"},
                "is_student": {"type": "boolean"},
                "courses": {
                    "type": "array",
                    "items": {"type": "string"}
                }
            }
        }

        task = "Generate a person's information based on the following schema:"
        agent = ToolAgent(model=model, tokenizer=tokenizer, json_schema=json_schema)
        generated_data = agent.run(task)

        print(generated_data)
    """

    def __init__(
        self,
        name: str = "Function Calling Agent",
        description: str = "Generates a function based on the input json schema and the task",
        model: Any = None,
        tokenizer: Any = None,
        json_schema: Any = None,
        max_number_tokens: int = 500,
        parsing_function: Optional[Callable] = None,
        llm: Any = None,
        *args,
        **kwargs,
    ):
        # ToolAgent's only base is ``object``, whose __init__ rejects keyword
        # arguments, so forwarding agent_name/llm/**kwargs to super() would
        # raise TypeError. Extra *args/**kwargs are still accepted (and
        # ignored) so existing call sites keep working.
        self.name = name
        self.description = description
        self.model = model
        self.tokenizer = tokenizer
        self.json_schema = json_schema
        self.max_number_tokens = max_number_tokens
        self.parsing_function = parsing_function
        # run() reads self.llm, so it must be stored explicitly here.
        self.llm = llm

    def _build_jsonformer(self, task: str, *args, **kwargs):
        """Create the Jsonformer for ``task``; model/tokenizer are only passed when a model is configured."""
        options = {
            "json_schema": self.json_schema,
            "llm": self.llm,
            "prompt": task,
            "max_number_tokens": self.max_number_tokens,
        }
        if self.model:
            options["model"] = self.model
            options["tokenizer"] = self.tokenizer
        return Jsonformer(*args, **options, **kwargs)

    def run(self, task: str, *args, **kwargs):
        """
        Run the tool agent for the specified task.

        A truthy ``model`` takes precedence over ``llm``; with only ``llm``
        set, Jsonformer is built without ``model``/``tokenizer``.

        Args:
            task (str): The task to be performed by the tool agent; used as
                the Jsonformer prompt.
            *args: Extra positional arguments forwarded to Jsonformer.
            **kwargs: Extra keyword arguments forwarded to Jsonformer.

        Returns:
            The Jsonformer output, passed through ``parsing_function`` when
            one is configured.

        Raises:
            ValueError: If neither ``model`` nor ``llm`` is set.
            Exception: Any error from building or invoking Jsonformer is
                logged and re-raised.
        """
        try:
            if not self.model and not self.llm:
                raise ValueError(
                    "Either model or llm should be provided to the"
                    " ToolAgent"
                )

            logger.info(f"Running {self.name} for task: {task}")
            self.toolagent = self._build_jsonformer(task, *args, **kwargs)

            out = self.toolagent()
            if self.parsing_function:
                out = self.parsing_function(out)
            return out

        except Exception:
            logger.error(
                f"Error running {self.name} for task: {task}"
            )
            # Bare raise re-raises the active exception with its traceback.
            raise

    def __call__(self, task: str, *args, **kwargs):
        """Alias for :meth:`run`."""
        return self.run(task, *args, **kwargs)