{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c0075889",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Install MiniChain into the running kernel's environment (%pip targets the kernel's\n",
    "# interpreter, whereas !pip may resolve to a different one), then clone the repository\n",
    "# and copy its example files into the working directory.\n",
    "%pip install -q git+https://github.com/srush/MiniChain\n",
    "!git clone https://github.com/srush/MiniChain; cp -fr MiniChain/examples/* . "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "033f7bd9",
   "metadata": {
    "lines_to_next_cell": 2
   },
   "outputs": [],
   "source": [
    "desc = \"\"\"\n",
    "### Prompt-aided Language Models\n",
    "\n",
    "Chain for answering complex problems by code generation and execution. [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/srush/MiniChain/blob/master/examples/pal.ipynb)\n",
    "\n",
    "(Adapted from Prompt-aided Language Models [PAL](https://arxiv.org/pdf/2211.10435.pdf)).\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ada25bcd",
   "metadata": {
    "lines_to_next_cell": 1
   },
   "outputs": [],
   "source": [
    "from minichain import prompt, show, OpenAI, Python"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3249f5ac",
   "metadata": {
    "lines_to_next_cell": 1
   },
   "outputs": [],
   "source": [
    "@prompt(OpenAI(), template_file=\"pal.pmpt.tpl\")\n",
    "def pal_prompt(model, question):\n",
    "    \"\"\"Pass `question` to the OpenAI-backed prompt as the `question` variable and return its result.\n",
    "\n",
    "    The prompt is configured with the pal.pmpt.tpl template file; the result is\n",
    "    presumably Python source defining `solution()` (the downstream step calls it).\n",
    "    \"\"\"\n",
    "    template_vars = {\"question\": question}\n",
    "    return model(template_vars)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "37f265c9",
   "metadata": {
    "lines_to_next_cell": 1
   },
   "outputs": [],
   "source": [
    "@prompt(Python())\n",
    "def python(model, inp):\n",
    "    \"\"\"Send `inp` plus a trailing `print(solution())` line to the Python backend; return a float.\"\"\"\n",
    "    program = inp + \"\\nprint(solution())\"\n",
    "    printed = model(program)\n",
    "    return float(printed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e767c1eb",
   "metadata": {},
   "outputs": [],
   "source": [
    "def pal(question):\n",
    "    \"\"\"Chain the two prompts: feed the output of pal_prompt(question) into python().\"\"\"\n",
    "    generated = pal_prompt(question)\n",
    "    return python(generated)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f2a4f241",
   "metadata": {},
   "outputs": [],
   "source": [
    "question = \"Melanie is a door-to-door saleswoman. She sold a third of her \" \\\n",
    "    \"vacuum cleaners at the green house, 2 more to the red house, and half of \" \\\n",
    "    \"what was left at the orange house. If Melanie has 5 vacuum cleaners left, \" \\\n",
    "    \"how many did she start with?\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c22e7837",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Wrap pal() in an interactive demo, listing both prompt steps as sub-prompts.\n",
    "gradio = show(\n",
    "    pal,\n",
    "    examples=[question],\n",
    "    subprompts=[pal_prompt, python],\n",
    "    description=desc,\n",
    "    out_type=\"json\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8790a65e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Launch the demo only when run as the main module (true inside a notebook kernel\n",
    "# or when executed as a script), not when this code is imported.\n",
    "if __name__ == \"__main__\":\n",
    "    gradio.launch()"
   ]
  }
 ],
 "metadata": {
  "jupytext": {
   "cell_metadata_filter": "-all",
   "main_language": "python",
   "notebook_metadata_filter": "-all"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}