|
|
|
|
|
desc = """ |
|
### Agent |
|
|
|
Chain that executes different tools based on model decisions. [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/srush/MiniChain/blob/master/examples/bash.ipynb) |
|
|
|
(Adapted from LangChain)
|
""" |
|
|
|
|
|
|
|
|
|
from minichain import Id, prompt, OpenAI, show, transform, Mock, Break |
|
from gradio_tools.tools import StableDiffusionTool, ImageCaptioningTool, ImageToMusicTool |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# One instance of each Gradio-backed tool. Order matters: ``tool_use`` passes
# the position of the selected tool in this list as ``tool_num``.
tools = [
    tool_cls()
    for tool_cls in (StableDiffusionTool, ImageCaptioningTool, ImageToMusicTool)
]
|
|
|
|
|
@prompt(OpenAI(stop=["Observation:"]),
        template_file="agent.pmpt.tpl")
def agent(model, query, history):
    """Ask the LLM whether (and how) to use a tool for the next step.

    Renders ``agent.pmpt.tpl`` with the tool catalogue, the user's query and
    the scratchpad of earlier steps. Generation stops at ``"Observation:"``,
    which is the marker ``append`` later writes in front of the tool output.

    Args:
        model: Backend callable injected by ``@prompt``.
        query: The user's request.
        history: Scratchpad text accumulated from previous rounds.

    Returns:
        Whatever ``model`` returns for the rendered prompt (the raw completion).
    """
    # (class name, description) per tool, in ``tools`` order; the class name
    # is what ``tool_use`` later matches the model's choice against.
    tool_specs = [(tool.__class__.__name__, tool.description) for tool in tools]
    template_vars = {
        "tools": tool_specs,
        "input": query,
        "agent_scratchpad": history,
    }
    return model(template_vars)
|
@transform()
def tool_parse(out):
    """Split the agent's completion into a ``(tool, command)`` pair.

    Parsing is purely positional:
      * line 0: the text after its last ``?`` must be ``"Yes"`` for a tool
        to be used;
      * line 1: everything after its first ``:`` is the tool name;
      * line 2: everything after its first ``:`` is the tool's input.
    (Presumably these are the "Do I need to use a tool?" / "Action:" /
    "Action Input:" lines requested by ``agent.pmpt.tpl`` — confirm against
    the template.)

    Args:
        out: Raw completion text produced by ``agent``.

    Returns:
        ``(tool, command)`` with surrounding whitespace stripped, or
        ``Break()`` when the model did not answer "Yes" or its answer lacks
        the two lines naming the tool and its input.
    """
    lines = out.split("\n")
    wants_tool = lines[0].split("?")[-1].strip() == "Yes"
    # A "Yes" without both follow-up lines (e.g. a truncated completion)
    # previously raised IndexError; treat it like "no tool" instead.
    if not wants_tool or len(lines) < 3:
        return Break()
    tool = lines[1].split(":", 1)[-1].strip()
    command = lines[2].split(":", 1)[-1].strip()
    return tool, command
|
|
|
@prompt(tools)
def tool_use(model, usage):
    """Run ``command`` on the tool whose class name equals ``selector``.

    Args:
        model: Backend callable injected by ``@prompt(tools)``; it is called
            with ``tool_num`` set to the index of the matching entry in
            ``tools``.
        usage: ``(selector, command)`` pair, as produced by ``tool_parse``.

    Returns:
        The selected tool's output, or ``""`` when no tool's class name
        matches ``selector``.
    """
    selector, command = usage
    for i, tool in enumerate(tools):
        if selector == tool.__class__.__name__:
            return model(command, tool_num=i)
    # Unknown tool name: produce an empty observation. This must be a plain
    # string (it used to be the tuple ``("",)``) because ``append`` joins the
    # observation onto the history with ``+``, which fails for a tuple.
    return ""
|
|
|
@transform()
def append(history, new, observation):
    """Extend the agent scratchpad with one completed step.

    The result is ``history``, a newline, the model's raw step text ``new``,
    then ``"Observation: "`` followed by the tool output. Pieces are joined
    with ``+``, so all three arguments are expected to be strings.
    """
    step = new + "Observation: " + observation
    return history + "\n" + step
|
|
|
def run(query):
    """Chain three agent -> tool rounds for ``query``.

    Each round asks ``agent`` for a decision given the scratchpad so far,
    parses it with ``tool_parse``, runs the chosen tool via ``tool_use``, and
    records the step in the scratchpad with ``append``.

    Returns:
        The value returned by the final ``tool_use`` call.
    """
    history = ""
    for _ in range(3):
        decision = agent(query, history)
        observation = tool_use(tool_parse(decision))
        history = append(history, decision, observation)
    return observation
|
|
|
|
|
|
|
gradio = show(
    run,
    description=desc,
    out_type="markdown",
    show_advanced=False,
    # One agent/tool_use pair per round; presumably this mirrors the
    # ``range(3)`` loop in ``run`` — keep the multiplier in sync with it.
    subprompts=[agent, tool_use] * 3,
    examples=[
        "I would please like a photo of a dog riding a skateboard. "
        "Please caption this image and create a song for it.",
        "Use an image generator tool to draw a cat.",
        "Caption the image https://huggingface.co/datasets/huggingface/brand-assets/resolve/main/hf-logo.png from the internet",
    ],
)

if __name__ == "__main__":
    gradio.queue().launch()
|
|
|
|