# steer-hexter/app.py
from transformer_lens import HookedTransformer
from sae_lens import SAE
import torch
if torch.backends.mps.is_available():
device = "mps"
else:
device = "cuda" if torch.cuda.is_available() else "cpu"
class Inference:
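    """Wrap a HookedTransformer and a pretrained sparse autoencoder (SAE) to do
    activation steering: the SAE features most active on a steering prompt give
    decoder directions that are added (scaled by `coeff`) to the residual stream
    while generating."""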
def __init__(self, model, pretrained_sae, layer):
self.layer = layer
if model == "gemma-2b":
self.sae_id = f"blocks.{layer}.hook_resid_post"
elif model == "gpt2-small":
            print(f"using {model}")
            # note: the gpt2-small SAE used here is pinned to block 0's resid_pre hook point,
            # independent of the `layer` argument
            self.sae_id = "blocks.0.hook_resid_pre"
self.sampling_kwargs = dict(temperature=1.0, top_p=0.1, freq_penalty=1.0)
self.set_coeff(1)
self.set_model(model)
self.set_SAE(pretrained_sae)
def set_model(self, model):
        self.model = HookedTransformer.from_pretrained(model, device=device)
def set_coeff(self, coeff):
self.coeff = coeff
def set_temperature(self, temperature):
self.sampling_kwargs['temperature'] = temperature
def set_steering_vector_prompt(self, prompt: str):
self.steering_vector_prompt = prompt
def set_SAE(self, sae_name):
        sae, cfg_dict, _ = SAE.from_pretrained(
            release=sae_name,
            sae_id=self.sae_id,
            device=device,
        )
self.sae = sae
self.cfg_dict = cfg_dict
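    # Steering pipeline implemented by the helpers below:
    #   1. run the steering prompt through the model and cache activations at the SAE hook point
    #   2. encode those activations with the SAE to get per-feature activations
    #   3. take the top feature at each prompt position
    #   4. use each feature's decoder row (sae.W_dec[feature]) as a steering vector
    #   5. during generation, a hook adds coeff * steering_vector to the residual stream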
def _get_sae_out_and_feature_activations(self):
        # run the steering prompt through the model, cache activations, and encode them with the SAE to get per-feature activations
sv_logits, activationCache = self.model.run_with_cache(self.steering_vector_prompt, prepend_bos=True)
sv_feature_acts = self.sae.encode(activationCache[self.sae.cfg.hook_name])
return self.sae.decode(sv_feature_acts), sv_feature_acts
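    # run generation with the steering hooks temporarily attached via model.hooks()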
def _hooked_generate(self, prompt_batch, fwd_hooks, seed=None, **kwargs):
if seed is not None:
torch.manual_seed(seed)
with self.model.hooks(fwd_hooks=fwd_hooks):
tokenized = self.model.to_tokens(prompt_batch)
result = self.model.generate(
stop_at_eos=False, # avoids a bug on MPS
input=tokenized,
max_new_tokens=50,
do_sample=True,
**kwargs)
return result
def _get_features(self, sv_feature_activations):
        # keep only the single most active SAE feature at each position of the steering prompt
features = torch.topk(sv_feature_activations, 1).indices
print(f'features that align with the text prompt: {features}')
        print("look these feature indices up (e.g. on Neuronpedia) to see which concepts they correspond to")
return features
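    # build a forward hook that nudges the residual stream toward one feature's decoder direction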
def _get_steering_hook(self, feature, sae_out):
        coeff = self.coeff
        # the steering vector is the chosen feature's decoder direction
        steering_vector = self.sae.W_dec[feature]
        # drop the leading singleton dimension (shape [1, d_model] -> [d_model])
        steering_vector = steering_vector[0]

        def steering_hook(resid_pre, hook):
            # skip single-token forward passes (incremental decoding with the KV cache)
            if resid_pre.shape[1] == 1:
                return
            position = sae_out.shape[1]  # number of tokens in the steering prompt (incl. BOS)
            # add the scaled steering vector to the first (position - 1) token positions of the residual stream
            resid_pre[:, :position - 1, :] += coeff * steering_vector
return steering_hook
def _get_steering_hooks(self):
# TODO: refactor this. It works because sae_out.shape[1] = sv_feature_acts.shape[1] = len(features[0])
# you can manipulate views to retrieve hooks more cleanly
        # and not use the separate function _get_steering_hook()
sae_out, sv_feature_acts = self._get_sae_out_and_feature_activations()
features = self._get_features(sv_feature_acts)
steering_hooks = [self._get_steering_hook(feature, sae_out) for feature in features[0]]
return steering_hooks
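    # steered generation samples the same prompt three times and joins the results with a separator;
    # unsteered generation returns a single sample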
def _run_generate(self, example_prompt, steering_on: bool):
self.model.reset_hooks()
if steering_on:
steer_hooks = self._get_steering_hooks()
            editing_hooks = [(self.sae_id, steer_hook) for steer_hook in steer_hooks]
print(f"steering by {len(editing_hooks)} hooks")
res = self._hooked_generate([example_prompt] * 3, editing_hooks, seed=None, **self.sampling_kwargs)
else:
tokenized = self.model.to_tokens([example_prompt])
res = self.model.generate(
stop_at_eos=False, # avoids a bug on MPS
input=tokenized,
max_new_tokens=50,
do_sample=True,
**self.sampling_kwargs)
# Print results, removing the ugly beginning of sequence token
res_str = self.model.to_string(res[:, 1:])
response = ("\n\n" + "-" * 80 + "\n\n").join(res_str)
print(response)
return response
def generate(self, message: str, steering_on: bool):
return self._run_generate(message, steering_on)
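
# Minimal usage sketch outside the Gradio UI (assumes the model and SAE release below can be downloaded):
#   bot = Inference("gpt2-small", "gpt2-small-res-jb", layer=0)
#   bot.set_steering_vector_prompt("Golden Gate Bridge")
#   bot.set_coeff(300)
#   print(bot.generate("Tell me about your favorite landmark.", steering_on=True))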
# MODEL = "gemma-2b"
# PRETRAINED_SAE = "gemma-2b-res-jb"
MODEL = "gpt2-small"
PRETRAINED_SAE = "gpt2-small-res-jb"
LAYER = 10
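# loading gpt2-small and the pretrained SAE happens here at import time, so startup takes a moment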
chatbot_model = Inference(MODEL, PRETRAINED_SAE, LAYER)
import time
import gradio as gr
default_image = "Hexter-Hackathon.png"
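# both chat handlers generate the full response up front, then yield it one character at a time to simulate streaming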
def slow_echo(message, history):
result = chatbot_model.generate(message, False)
for i in range(len(result)):
time.sleep(0.01)
yield result[: i + 1]
def slow_echo_steering(message, history):
result = chatbot_model.generate(message, True)
for i in range(len(result)):
time.sleep(0.01)
yield result[: i + 1]
with gr.Blocks() as demo:
with gr.Row():
gr.Markdown("*STANDARD HEXTER BOT*")
with gr.Row():
chatbot = gr.ChatInterface(
slow_echo,
chatbot=gr.Chatbot(min_width=1000),
textbox=gr.Textbox(placeholder="Ask Hexter anything!", min_width=1000),
theme="soft",
cache_examples=False,
retry_btn=None,
clear_btn=None,
undo_btn=None,
)
with gr.Row():
gr.Markdown("*STEERED HEXTER BOT*")
with gr.Row():
chatbot_steered = gr.ChatInterface(
slow_echo_steering,
chatbot=gr.Chatbot(min_width=1000),
textbox=gr.Textbox(placeholder="Ask Hexter anything!", min_width=1000),
theme="soft",
cache_examples=False,
retry_btn=None,
clear_btn=None,
undo_btn=None,
)
with gr.Row():
steering_prompt = gr.Textbox(label="Steering prompt", value="Golden Gate Bridge")
with gr.Row():
        coeff = gr.Slider(1, 1000, 300, label="Coefficient", info="How strongly the steering vector is added to the model's residual stream.", interactive=True)
with gr.Row():
        temp = gr.Slider(0, 5, 1, label="Temperature", info="Sampling temperature; higher values make generations more random.", interactive=True)
temp.change(chatbot_model.set_temperature, inputs=[temp], outputs=[])
coeff.change(chatbot_model.set_coeff, inputs=[coeff], outputs=[])
chatbot_model.set_steering_vector_prompt(steering_prompt.value)
steering_prompt.change(chatbot_model.set_steering_vector_prompt, inputs=[steering_prompt], outputs=[])
demo.queue()
demo.launch(debug=True, allowed_paths=["/"])