# general imports
import os
import torch
from tqdm import tqdm
import plotly.express as px

torch.set_grad_enabled(False);

# package import
from torch import Tensor
from transformer_lens import utils
from functools import partial
from jaxtyping import Int, Float

# device setup
if torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device: {device}")
from transformer_lens import HookedTransformer | |
from sae_lens import SAE | |
# Choose a layer to focus on.
# tiny-stories-1L-21M has a single transformer block, so we use layer 0.
layer = 0
# get model
model = HookedTransformer.from_pretrained("tiny-stories-1L-21M", device=device)

# get the SAE for this layer
sae = SAE.load_from_pretrained("sae_tiny-stories-1L-21M_blocks.0.hook_mlp_out_16384", device=device)
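# Sanity check (optional; attribute names follow sae_lens's SAE config): per the
# release name, this SAE should expand the MLP output into 16384 sparse features.
print(f"SAE input dim: {sae.cfg.d_in}, number of features: {sae.cfg.d_sae}")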
# get hook point | |
hook_point = sae.cfg.hook_name | |
print(hook_point) | |
sv_prompt = " Lily" | |
sv_logits, activationCache = model.run_with_cache(sv_prompt, prepend_bos=True) | |
sv_feature_acts = sae.encode(activationCache[hook_point]) | |
print(torch.topk(sv_feature_acts, 3).indices.tolist()) | |
# Run the prompt through the model and cache its activations
sv_prompt = " Lily"
sv_logits, activation_cache = model.run_with_cache(sv_prompt, prepend_bos=True)
tokens = model.to_tokens(sv_prompt)
print(tokens)

# get the feature activations from our SAE
sv_feature_acts = sae.encode(activation_cache[hook_point])
# get sae_out | |
sae_out = sae.decode(sv_feature_acts) | |
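# A quick sanity check (a minimal sketch): compare the SAE's reconstruction to
# the original activations. Exact values depend on the trained SAE; smaller is better.
recon_error = (sae_out - activation_cache[hook_point]).pow(2).mean()
print(f"Mean squared reconstruction error: {recon_error.item():.4f}")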
# print out the top activations; focus on the indices
print(torch.topk(sv_feature_acts, 3))

# these are the SAE feature indices we can steer with
print(torch.topk(sv_feature_acts, 3).indices.tolist())
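# A hedged sketch: map each prompt token to its top-3 features, assuming
# sv_feature_acts has shape [batch, seq_len, d_sae].
top_vals, top_inds = torch.topk(sv_feature_acts[0], 3, dim=-1)
for tok, feat_ids, feat_vals in zip(
    model.to_str_tokens(sv_prompt), top_inds.tolist(), top_vals.tolist()
):
    print(f"{tok!r}: features {feat_ids}, activations {[round(v, 2) for v in feat_vals]}")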
# choose the steering vector: a decoder row for one of the feature indices found above
steering_vector = sae.W_dec[10284]
example_prompt = "Once upon a time" | |
coeff = 1000 | |
sampling_kwargs = dict(temperature=1.0, top_p=0.1, freq_penalty=1.0) | |
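# Optional context (a small sketch): print the steering vector's norm so the
# size of coeff can be judged against the residual stream's typical scale.
print(f"Steering vector norm: {steering_vector.norm().item():.3f}")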
# apply the steering vector while the model generates
def steering_hook(resid_pre, hook):
    # skip single-token forward passes (the KV-cached steps during generation)
    if resid_pre.shape[1] == 1:
        return

    # steer every prompt position except the last one
    position = sae_out.shape[1]
    if steering_on:
        # add the scaled decoder direction to the residual stream
        resid_pre[:, :position - 1, :] += coeff * steering_vector
def hooked_generate(prompt_batch, fwd_hooks=[], seed=None, **kwargs):
    if seed is not None:
        torch.manual_seed(seed)

    with model.hooks(fwd_hooks=fwd_hooks):
        tokenized = model.to_tokens(prompt_batch)
        result = model.generate(
            stop_at_eos=False,  # avoids a bug on MPS
            input=tokenized,
            max_new_tokens=50,
            do_sample=True,
            **kwargs,
        )
    return result
def run_generate(example_prompt):
    model.reset_hooks()
    editing_hooks = [(f"blocks.{layer}.hook_resid_post", steering_hook)]
    res = hooked_generate([example_prompt] * 3, editing_hooks, seed=None, **sampling_kwargs)

    # Print results, removing the ugly beginning-of-sequence token
    res_str = model.to_string(res[:, 1:])
    print(("\n\n" + "-" * 80 + "\n\n").join(res_str))
steering_on = True | |
run_generate(example_prompt) | |
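# For comparison, generate without steering: the hook stays registered but
# becomes a no-op because steering_on is False.
steering_on = False
run_generate(example_prompt)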
# evaluate features | |
import pandas as pd | |
# Let's start by getting the top 10 logits for each feature | |
projection_onto_unembed = sae.W_dec @ model.W_U | |
# get the top 10 logits. | |
vals, inds = torch.topk(projection_onto_unembed, 10, dim=1) | |
# get 10 random features | |
random_indices = torch.randint(0, projection_onto_unembed.shape[0], (10,)) | |
# Show the top 10 logits promoted by those features | |
top_10_logits_df = pd.DataFrame(
    [model.to_str_tokens(i) for i in inds[random_indices]],
    index=random_indices.tolist(),
).T
top_10_logits_df | |
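# A hedged sketch: pair each promoted token with its logit value for one of the
# randomly sampled features (vals and inds come from the torch.topk call above).
feature_idx = random_indices[0].item()
for tok, val in zip(model.to_str_tokens(inds[feature_idx]), vals[feature_idx].tolist()):
    print(f"{tok!r}: {val:.2f}")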
# The top feature indices for the " Lily" prompt above were [7195, 5910, 2041].
# See the words associated with feature 7195 (should be "Golden"):
model.to_str_tokens(inds[7195])