{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "GoXn14ltnGh3" }, "source": [ "# Using an SAE as a steering vector\n", "\n", "This notebook demonstrates how to use SAE lens to identify a feature on a pretrained model, and then construct a steering vector to affect the models output to various prompts. This notebook will also make use of Neuronpedia for identifying features of interest.\n", "\n", "The steps below include:\n", "\n", "\n", "\n", "* Installing relevant packages (Colab or locally)\n", "* Load your SAE and the model it used\n", "* Determining your feature of interest and its index\n", "* Implementing your steering vector\n", "\n", "\n", "\n" ] }, { "cell_type": "markdown", "metadata": { "id": "gf3lJYPEXh0v" }, "source": [ "## Setting up packages and notebook" ] }, { "cell_type": "markdown", "metadata": { "id": "l9k5iGyOXtuN" }, "source": [ "### Import and installs" ] }, { "cell_type": "markdown", "metadata": { "id": "fapxk8MDrs6R" }, "source": [ "#### Environment Setup\n" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "collapsed": true, "id": "0TwNmRkRUgR7", "outputId": "ffeb827a-9af2-4b09-b8dd-78e0d594ddf6" }, "outputs": [], "source": [ "try:\n", " # for google colab users\n", " import google.colab # type: ignore\n", " from google.colab import output\n", " COLAB = True\n", " %pip install sae-lens transformer-lens\n", "except:\n", " # for local setup\n", " COLAB = False\n", " from IPython import get_ipython # type: ignore\n", " ipython = get_ipython(); assert ipython is not None\n", " ipython.run_line_magic(\"load_ext\", \"autoreload\")\n", " ipython.run_line_magic(\"autoreload\", \"2\")\n", "\n", "# Imports for displaying vis in Colab / notebook\n", "import webbrowser\n", "import http.server\n", "import socketserver\n", "import threading\n", "PORT = 8000\n", "\n", "# general imports\n", "import os\n", "import torch\n", "from tqdm import tqdm\n", "import plotly.express as px\n", "\n", "torch.set_grad_enabled(False);" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "id": "NGgIu1ZVYDub" }, "outputs": [], "source": [ "def display_vis_inline(filename: str, height: int = 850):\n", " '''\n", " Displays the HTML files in Colab. 
{ "cell_type": "code", "execution_count": 2, "metadata": { "id": "NGgIu1ZVYDub" }, "outputs": [], "source": [
"def display_vis_inline(filename: str, height: int = 850):\n",
"    '''\n",
"    Displays an HTML file in Colab or the local notebook.\n",
"\n",
"    Uses the global `PORT` variable defined in the previous cell, so that each\n",
"    vis gets a unique port without having to define a port within the function.\n",
"    '''\n",
"    if not COLAB:\n",
"        webbrowser.open(filename)\n",
"    else:\n",
"        global PORT\n",
"\n",
"        def serve(directory):\n",
"            os.chdir(directory)\n",
"\n",
"            # create a handler for serving files\n",
"            handler = http.server.SimpleHTTPRequestHandler\n",
"\n",
"            # create a socket server with the handler\n",
"            with socketserver.TCPServer((\"\", PORT), handler) as httpd:\n",
"                print(f\"Serving files from {directory} on port {PORT}\")\n",
"                httpd.serve_forever()\n",
"\n",
"        thread = threading.Thread(target=serve, args=(\"/content\",))\n",
"        thread.start()\n",
"\n",
"        output.serve_kernel_port_as_iframe(\n",
"            PORT, path=f\"/{filename}\", height=height, cache_in_notebook=True\n",
"        )\n",
"\n",
"        PORT += 1" ] },
{ "cell_type": "markdown", "metadata": { "id": "CmaPYLpGrxbo" }, "source": [ "#### General imports and device setup" ] },
{ "cell_type": "code", "execution_count": 3, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "tdUm9rZKr1Qb", "outputId": "9b73b762-1356-437b-8925-91c514093b43" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Device: mps\n" ] } ], "source": [
"# package imports\n",
"from torch import Tensor\n",
"from transformer_lens import utils\n",
"from functools import partial\n",
"from jaxtyping import Int, Float\n",
"\n",
"# device setup\n",
"if torch.backends.mps.is_available():\n",
"    device = \"mps\"\n",
"else:\n",
"    device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
"\n",
"print(f\"Device: {device}\")" ] },
{ "cell_type": "markdown", "metadata": { "id": "lsB0qORUaXiK" }, "source": [ "### Load your model and SAE\n", "\n", "We're going to work with a pretrained GPT2-small model and the RES-JB SAE set, which was trained on the residual stream. The cell below follows the standard SAELens loading pattern for this SAE set." ] },
{ "cell_type": "code", "execution_count": 40, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "collapsed": true, "id": "bCvNtm1OOhlR", "outputId": "e6fd27ab-ee94-46ec-a07e-ee48c8f30da3" }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "8607cfc3f17548078c7b3ff7ebcca055", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [
"from transformer_lens import HookedTransformer\n",
"from sae_lens import SAE\n",
"\n",
"model = HookedTransformer.from_pretrained(\"gpt2\", device=device)\n",
"\n",
"# the layer 2 residual-stream SAE from the RES-JB set;\n",
"# `layer` is reused by the generation hooks further down\n",
"layer = 2\n",
"sae, cfg_dict, sparsity = SAE.from_pretrained(\n",
"    release=\"gpt2-small-res-jb\",\n",
"    sae_id=f\"blocks.{layer}.hook_resid_pre\",\n",
"    device=device,\n",
")" ] },
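{ "cell_type": "markdown", "metadata": {}, "source": [ "## Determine your feature of interest and its index\n", "\n", "To find a feature of interest, we run a prompt through the model, cache its activations at the SAE's hook point, and encode those activations with the SAE to get per-token feature activations.\n", "\n", "The cell below is a minimal sketch of this step. The prompt \"Jedi\" and the intermediate names `sv_prompt`, `sv_logits` and `cache` are illustrative assumptions; `sv_feature_acts` and `sae_out` are the names the later cells rely on.\n" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"sv_prompt = \"Jedi\"  # assumed example prompt\n",
"sv_logits, cache = model.run_with_cache(sv_prompt, prepend_bos=True)\n",
"print(model.to_str_tokens(sv_prompt))\n",
"\n",
"# get the feature activations from our SAE at its hook point\n",
"sv_feature_acts = sae.encode(cache[sae.cfg.hook_name])\n",
"\n",
"# reconstruct the activations; the hooks below reuse sae_out for its sequence length\n",
"sae_out = sae.decode(sv_feature_acts)\n",
"\n",
"# inspect the top 3 features at each token position\n",
"print(torch.topk(sv_feature_acts, 3))" ] },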
{ "cell_type": "markdown", "metadata": {}, "source": [ "Our prompt is tokenized into \"<|endoftext|>\", \"J\", and \"edi\".\n", "\n", "The feature activations at `sv_feature_acts[2]` - for \"edi\" - are of most interest to us.\n", "\n", "Because we are using pretrained SAEs that have published feature dashboards, you can search on Neuronpedia for a feature of interest." ] },
{ "cell_type": "markdown", "metadata": { "id": "gFv4iBHFcOmE" }, "source": [ "### Steps for Neuronpedia use\n", "\n", "Use the interface to search for a specific concept or item and determine which layer and index it lives at.\n", "\n", "1. Open the [Neuronpedia](https://www.neuronpedia.org/) homepage.\n", "2. Using the \"Models\" dropdown, select your model. Here we are using GPT2-SM (GPT2-small).\n", "3. The next page has a search bar that allows you to enter your indices of interest. We're interested in the \"RES-JB\" SAE set, so make sure to select it.\n", "4. We found these indices in the previous step: [7650, 718, 22372]. Select them in the search to see the feature dashboard for each.\n", "5. As we'll see, some of the indices may relate to features you don't care about.\n", "\n", "From using Neuronpedia, I have determined that my feature of interest is in layer 2, at index 7650: [here](https://www.neuronpedia.org/gpt2-small/2-res-jb/7650) is the feature." ] },
{ "cell_type": "markdown", "metadata": { "id": "KX0rXziniH9O" }, "source": [ "### Note: second option - starting with Neuronpedia\n", "\n", "Another option is to start with Neuronpedia to identify features of interest. By entering your prompt in the interface you can explore which features it activates, searching across all the layers. This lets you determine your layer and index of interest in Neuronpedia first, before focusing on them in your code. Start [here](https://www.neuronpedia.org/search) if you want to begin with search." ] },
{ "cell_type": "markdown", "metadata": { "id": "YACtNFzGcNua" }, "source": [ "## Implement your steering vector and affect the output" ] },
{ "cell_type": "markdown", "metadata": { "id": "pO8hjg8j5bb-" }, "source": [ "### Define values for your steering vector\n", "\n", "To create our steering vector, we take the row of our sparse autoencoder's decoder weights at our index of interest.\n", "\n", "To use the steering vector, we also need a prompt for text generation, as well as a scaling coefficient to apply to the steering vector.\n", "\n", "We also set common sampling kwargs: temperature, top_p and freq_penalty." ] },
{ "cell_type": "code", "execution_count": 46, "metadata": { "id": "rgYEWGV0t0L2" }, "outputs": [], "source": [
"# decoder direction for our feature of interest (layer 2, index 7650)\n",
"steering_vector = sae.W_dec[7650]\n",
"\n",
"example_prompt = \"What is the most iconic structure known to man?\"\n",
"coeff = 300\n",
"sampling_kwargs = dict(temperature=1.0, top_p=0.1, freq_penalty=1.0)" ] },
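{ "cell_type": "markdown", "metadata": {}, "source": [ "The coefficient has to be large because the steering vector is added directly to the residual stream, whose activations have a much larger norm than a single decoder row. As an optional check - a minimal sketch that assumes the `cache` computed in the feature-identification step above - you can compare the scales:\n" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"# compare scales: decoder rows for these SAEs are typically near unit norm,\n",
"# while the residual stream at this layer is much larger - hence coeff = 300\n",
"print(\"steering vector norm:\", steering_vector.norm().item())\n",
"print(\"residual stream norms:\", cache[sae.cfg.hook_name].norm(dim=-1))" ] },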
{ "cell_type": "markdown", "metadata": { "id": "cexaoBR65lIa" }, "source": [ "### Set up hook functions\n", "\n", "Finally, we need to create a hook that applies the steering vector when our model runs generate() on our defined prompt. We also add a boolean `steering_on` that lets us easily toggle the steering vector on and off for each prompt.\n" ] },
{ "cell_type": "code", "execution_count": 47, "metadata": { "collapsed": true, "id": "3kcVWeJoIAlC" }, "outputs": [], "source": [
"def steering_hook(resid_pre, hook):\n",
"    # during generation with KV caching only the newest token is passed\n",
"    # through the model, so skip steering on those single-position passes\n",
"    if resid_pre.shape[1] == 1:\n",
"        return\n",
"\n",
"    # sae_out comes from the feature-identification step, so this steers\n",
"    # the first sae_out.shape[1] - 1 token positions of the sequence\n",
"    position = sae_out.shape[1]\n",
"    if steering_on:\n",
"        # using our steering vector and applying the coefficient\n",
"        resid_pre[:, : position - 1, :] += coeff * steering_vector\n",
"\n",
"\n",
"def hooked_generate(prompt_batch, fwd_hooks=[], seed=None, **kwargs):\n",
"    if seed is not None:\n",
"        torch.manual_seed(seed)\n",
"\n",
"    with model.hooks(fwd_hooks=fwd_hooks):\n",
"        tokenized = model.to_tokens(prompt_batch)\n",
"        result = model.generate(\n",
"            stop_at_eos=False,  # avoids a bug on MPS\n",
"            input=tokenized,\n",
"            max_new_tokens=50,\n",
"            do_sample=True,\n",
"            **kwargs,\n",
"        )\n",
"    return result" ] },
{ "cell_type": "code", "execution_count": 48, "metadata": { "id": "VcuRkX0yA2WH" }, "outputs": [], "source": [
"def run_generate(example_prompt):\n",
"    model.reset_hooks()\n",
"    editing_hooks = [(f\"blocks.{layer}.hook_resid_post\", steering_hook)]\n",
"    res = hooked_generate(\n",
"        [example_prompt] * 3, editing_hooks, seed=None, **sampling_kwargs\n",
"    )\n",
"\n",
"    # print results, removing the ugly beginning-of-sequence token\n",
"    res_str = model.to_string(res[:, 1:])\n",
"    print((\"\\n\\n\" + \"-\" * 80 + \"\\n\\n\").join(res_str))" ] },
{ "cell_type": "markdown", "metadata": { "id": "XYx--hIn61VQ" }, "source": [ "### Generate text influenced by the steering vector\n", "\n", "You may want to experiment with the scaling coefficient `coeff` you set above and see how it affects the generated output." ] },
{ "cell_type": "code", "execution_count": 49, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 337, "referenced_widgets": [ "9f555c5ada38495eb4281cbb49169abe", "79b59cbde9444bf892931d31afec7f2a", "a157870318114d459a33d795850967ef", "635162e10abc441797d4e5b74713bf44", "720b4d010c364e3fbf72a53b267e8db9", "d9c33fbfb3164cbbb7b9a4cd172d20ae", "df53331cce124bd1ada5aa9e9a977015", "229dad8e29f04c279c5603286e2c0643", "83d947fc3338491ab4155b87c443884c", "5e9700580d6b4ad0bfac34bf3b3919fc", "a2c30462ef8d41fd9158f194a746d5a7" ] }, "id": "hN_YOzBE6lz8", "outputId": "e263b8ff-86ce-439e-81e5-bbecb0d7e187" }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "634ddfad68cb49208e63733402859842", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/50 [00:00