Spaces:
Sleeping
Sleeping
from transformer_lens import HookedTransformer | |
from sae_lens import SAE | |
import torch | |
if torch.backends.mps.is_available(): | |
device = "mps" | |
else: | |
device = "cuda" if torch.cuda.is_available() else "cpu" | |
class Inference: | |
def __init__(self, model, pretrained_sae, layer): | |
self.layer = layer | |
if model == "gemma-2b": | |
self.sae_id = f"blocks.{layer}.hook_resid_post" | |
elif model == "gpt2-small": | |
print(f"using {model}") | |
self.sae_id = f"blocks.{0}.hook_resid_pre" | |
self.sampling_kwargs = dict(temperature=1.0, top_p=0.1, freq_penalty=1.0) | |
self.set_coeff(1) | |
self.set_model(model) | |
self.set_SAE(pretrained_sae) | |
def set_model(self, model): | |
self.model = HookedTransformer.from_pretrained(model, device = device) | |
def set_coeff(self, coeff): | |
self.coeff = coeff | |
def set_temperature(self, temperature): | |
self.sampling_kwargs['temperature'] = temperature | |
def set_steering_vector_prompt(self, prompt: str): | |
self.steering_vector_prompt = prompt | |
def set_SAE(self, sae_name): | |
sae, cfg_dict, _ = SAE.from_pretrained( | |
release = sae_name, | |
sae_id = self.sae_id, | |
device = device | |
) | |
self.sae = sae | |
self.cfg_dict = cfg_dict | |
def _get_sae_out_and_feature_activations(self): | |
# given the words in steering_vector_prompt, the SAE predicts that the neurons(aka features) in activateCache will be activated | |
sv_logits, activationCache = self.model.run_with_cache(self.steering_vector_prompt, prepend_bos=True) | |
sv_feature_acts = self.sae.encode(activationCache[self.sae.cfg.hook_name]) | |
return self.sae.decode(sv_feature_acts), sv_feature_acts | |
def _hooked_generate(self, prompt_batch, fwd_hooks, seed=None, **kwargs): | |
if seed is not None: | |
torch.manual_seed(seed) | |
with self.model.hooks(fwd_hooks=fwd_hooks): | |
tokenized = self.model.to_tokens(prompt_batch) | |
result = self.model.generate( | |
stop_at_eos=False, # avoids a bug on MPS | |
input=tokenized, | |
max_new_tokens=50, | |
do_sample=True, | |
**kwargs) | |
return result | |
def _get_features(self, sv_feature_activations): | |
# return torch.topk(sv_feature_acts, 1).indices.tolist() | |
features = torch.topk(sv_feature_activations, 1).indices | |
print(f'features that align with the text prompt: {features}') | |
print("pump the features into the tool that gives you the words associated with each feature") | |
return features | |
def _get_steering_hook(self, feature, sae_out): | |
coeff = self.coeff | |
steering_vector = self.sae.W_dec[feature] | |
steering_vector = steering_vector[0] | |
def steering_hook(resid_pre, hook): | |
if resid_pre.shape[1] == 1: | |
return | |
position = sae_out.shape[1] | |
# using our steering vector and applying the coefficient | |
resid_pre[:, :position - 1, :] += coeff * steering_vector | |
return steering_hook | |
def _get_steering_hooks(self): | |
# TODO: refactor this. It works because sae_out.shape[1] = sv_feature_acts.shape[1] = len(features[0]) | |
# you can manipulate views to retrieve hooks more cleanly | |
# and not use the seperate function _get_steering_hook() | |
sae_out, sv_feature_acts = self._get_sae_out_and_feature_activations() | |
features = self._get_features(sv_feature_acts) | |
steering_hooks = [self._get_steering_hook(feature, sae_out) for feature in features[0]] | |
return steering_hooks | |
def _run_generate(self, example_prompt, steering_on: bool): | |
self.model.reset_hooks() | |
if steering_on: | |
steer_hooks = self._get_steering_hooks() | |
editing_hooks = [ (self.sae_id, steer_hook) for steer_hook in steer_hooks] | |
print(f"steering by {len(editing_hooks)} hooks") | |
res = self._hooked_generate([example_prompt] * 3, editing_hooks, seed=None, **self.sampling_kwargs) | |
else: | |
tokenized = self.model.to_tokens([example_prompt]) | |
res = self.model.generate( | |
stop_at_eos=False, # avoids a bug on MPS | |
input=tokenized, | |
max_new_tokens=50, | |
do_sample=True, | |
**self.sampling_kwargs) | |
# Print results, removing the ugly beginning of sequence token | |
res_str = self.model.to_string(res[:, 1:]) | |
response = ("\n\n" + "-" * 80 + "\n\n").join(res_str) | |
print(response) | |
return response | |
def generate(self, message: str, steering_on: bool): | |
return self._run_generate(message, steering_on) | |
# MODEL = "gemma-2b" | |
# PRETRAINED_SAE = "gemma-2b-res-jb" | |
MODEL = "gpt2-small" | |
PRETRAINED_SAE = "gpt2-small-res-jb" | |
LAYER = 10 | |
chatbot_model = Inference(MODEL, PRETRAINED_SAE, LAYER) | |
import time | |
import gradio as gr | |
default_image = "Hexter-Hackathon.png" | |
def slow_echo(message, history): | |
result = chatbot_model.generate(message, False) | |
for i in range(len(result)): | |
time.sleep(0.01) | |
yield result[: i + 1] | |
def slow_echo_steering(message, history): | |
result = chatbot_model.generate(message, True) | |
for i in range(len(result)): | |
time.sleep(0.01) | |
yield result[: i + 1] | |
with gr.Blocks() as demo: | |
with gr.Row(): | |
gr.Markdown("*STANDARD HEXTER BOT*") | |
with gr.Row(): | |
chatbot = gr.ChatInterface( | |
slow_echo, | |
chatbot=gr.Chatbot(min_width=1000), | |
textbox=gr.Textbox(placeholder="Ask Hexter anything!", min_width=1000), | |
theme="soft", | |
cache_examples=False, | |
retry_btn=None, | |
clear_btn=None, | |
undo_btn=None, | |
) | |
with gr.Row(): | |
gr.Markdown("*STEERED HEXTER BOT*") | |
with gr.Row(): | |
chatbot_steered = gr.ChatInterface( | |
slow_echo_steering, | |
chatbot=gr.Chatbot(min_width=1000), | |
textbox=gr.Textbox(placeholder="Ask Hexter anything!", min_width=1000), | |
theme="soft", | |
cache_examples=False, | |
retry_btn=None, | |
clear_btn=None, | |
undo_btn=None, | |
) | |
with gr.Row(): | |
steering_prompt = gr.Textbox(label="Steering prompt", value="Golden Gate Bridge") | |
with gr.Row(): | |
coeff = gr.Slider(1, 1000, 300, label="Coefficient", info="Coefficient is..", interactive=True) | |
with gr.Row(): | |
temp = gr.Slider(0, 5, 1, label="Temperature", info="Temperature is..", interactive=True) | |
temp.change(chatbot_model.set_temperature, inputs=[temp], outputs=[]) | |
coeff.change(chatbot_model.set_coeff, inputs=[coeff], outputs=[]) | |
chatbot_model.set_steering_vector_prompt(steering_prompt.value) | |
steering_prompt.change(chatbot_model.set_steering_vector_prompt, inputs=[steering_prompt], outputs=[]) | |
demo.queue() | |
demo.launch(debug=True) | |
if __name__ == "__main__": | |
demo.launch(allowed_paths=["/"]) | |