import os, json, random
import torch
import gradio as gr
import spaces
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from huggingface_hub import login, hf_hub_download
import pyreft
import pyvene as pv
from threading import Thread
from typing import Iterator
import torch.nn.functional as F
HF_TOKEN = os.environ.get("HF_TOKEN")
if HF_TOKEN:
    login(token=HF_TOKEN)
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 128 # smaller default to save memory
MAX_INPUT_TOKEN_LENGTH = 4096
css = """
#alert-message textarea {
background-color: #e8f4ff;
border: 1px solid #cce5ff;
color: #084298;
font-size: 1.1em;
padding: 12px;
border-radius: 4px;
font-weight: 500;
}
.concept-help {
font-size: 0.9em;
color: #666;
margin-top: 4px;
font-style: italic;
}
"""
def load_jsonl(jsonl_path):
jsonl_data = []
with open(jsonl_path, 'r') as f:
for line in f:
data = json.loads(line)
jsonl_data.append(data)
return jsonl_data
class Steer(pv.SourcelessIntervention):
"""Steer model via activation addition"""
def __init__(self, **kwargs):
super().__init__(**kwargs, keep_last_dim=True)
self.proj = torch.nn.Linear(
self.embed_dim, kwargs["latent_dim"], bias=False)
self.subspace_generator = kwargs["subspace_generator"]
def steer(self, base, source=None, subspaces=None):
if subspaces["steer"]["subspace_gen_inputs"] is not None:
# we call our subspace generator to generate the subspace on-the-fly.
raw_steering_vec = self.subspace_generator(
subspaces["steer"]["subspace_gen_inputs"]["input_ids"],
subspaces["steer"]["subspace_gen_inputs"]["attention_mask"],
)[0]
steering_vec = torch.tensor(subspaces["steer"]["mag"]) * \
raw_steering_vec.unsqueeze(dim=0)
return base + steering_vec
else:
steering_vec = torch.tensor(subspaces["steer"]["mag"]) * \
self.proj.weight[subspaces["steer"]["idx"]].unsqueeze(dim=0)
return base + steering_vec
def forward(self, base, source=None, subspaces=None):
if subspaces is None:
return base
if subspaces["detect"] is not None:
if subspaces["detect"]["subspace_gen_inputs"] is not None:
# we call our subspace generator to generate the subspace on-the-fly.
raw_detection_vec = self.subspace_generator(
subspaces["detect"]["subspace_gen_inputs"]["input_ids"],
subspaces["detect"]["subspace_gen_inputs"]["attention_mask"],
)[0].unsqueeze(dim=-1)
else:
raw_detection_vec = self.proj.weight[subspaces["detect"]["idx"]].unsqueeze(dim=-1)
print(base.shape)
print(raw_detection_vec.shape)
detection_latent = torch.matmul(base, raw_detection_vec.to(base.dtype)).squeeze(dim=-1) # (batch_size, seq, 1) -> (batch_size, seq)
max_latent = torch.max(detection_latent, dim=-1).values[0] # max over the sequence, then take the single batch element -> scalar
print("max_latent", max_latent)
if max_latent > torch.tensor(subspaces["detect"]["mag"]):
print("Detected!")
return self.steer(base, source, subspaces)
else:
return base
else:
return self.steer(base, source, subspaces)
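# Subspace generator: wraps the base LM with a linear regression head that maps the
# last-token hidden state of a concept description to a unit-norm steering direction.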
class RegressionWrapper(torch.nn.Module):
def __init__(self, base_model, hidden_size, output_dim):
super().__init__()
self.base_model = base_model
self.regression_head = torch.nn.Linear(hidden_size, output_dim)
def forward(self, input_ids, attention_mask):
outputs = self.base_model.model(
input_ids=input_ids,
attention_mask=attention_mask,
output_hidden_states=True,
return_dict=True
)
last_hiddens = outputs.hidden_states[-1]
last_token_representations = last_hiddens[:, -1]
preds = self.regression_head(last_token_representations)
preds = F.normalize(preds, p=2, dim=-1)
return preds
# Check GPU
if not torch.cuda.is_available():
print("Warning: Running on CPU, may be slow.")
# Load model & dictionary
model_id = "google/gemma-2-2b-it"
pv_model = None
tokenizer = None
concept_list = []
concept_id_map = {}
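# The chat model, concept dictionary, and subspace generator are only loaded when a GPU is available.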
if torch.cuda.is_available():
model = AutoModelForCausalLM.from_pretrained(
model_id, device_map="cuda", torch_dtype=torch.bfloat16
)
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Download dictionary
weight_path = hf_hub_download(repo_id="pyvene/gemma-reft-2b-it-res", filename="l20/weight.pt")
meta_path = hf_hub_download(repo_id="pyvene/gemma-reft-2b-it-res", filename="l20/metadata.jsonl")
params = torch.load(weight_path).cuda()
md = load_jsonl(meta_path)
concept_list = [item["concept"] for item in md]
concept_id_map = {}
# Reindex because one concept is missing from the dictionary.
concept_reindex = 0
for item in md:
concept_id_map[item["concept"]] = concept_reindex
concept_reindex += 1
# load subspace generator.
base_tokenizer = AutoTokenizer.from_pretrained(
"google/gemma-2-2b", model_max_length=512)
config = AutoConfig.from_pretrained("google/gemma-2-2b")
base_model = AutoModelForCausalLM.from_config(config)
subspace_generator_weight_path = hf_hub_download(repo_id="pyvene/gemma-reft-2b-it-res-generator", filename="l20/weight.pt")
hidden_size = base_model.config.hidden_size
subspace_generator = RegressionWrapper(
base_model, hidden_size, hidden_size).bfloat16().to("cuda")
subspace_generator.load_state_dict(torch.load(subspace_generator_weight_path))
print(f"Loading model from saved file {subspace_generator_weight_path}")
_ = subspace_generator.eval()
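# Assemble the intervention: the dictionary weights become the projection matrix, and
# the subspace generator handles concepts that are not in the dictionary.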
steer = Steer(
embed_dim=params.shape[0], latent_dim=params.shape[1],
subspace_generator=subspace_generator)
steer.proj.weight.data = params.float()
pv_model = pv.IntervenableModel({
"component": f"model.layers[20].output",
"intervention": steer}, model=model)
terminators = [tokenizer.eos_token_id] if tokenizer else []
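# Streamed chat generation: build the prompt from recent history, then run pv_model.generate
# in a background thread with the detection/steering subspaces attached.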
@spaces.GPU
def generate(
message: str,
chat_history: list[dict],
detection_list: list[dict],
steering_list: list[dict],
max_new_tokens: int=DEFAULT_MAX_NEW_TOKENS,
) -> Iterator[str]:
# keep only the last 4 messages of history to bound prompt length
start_idx = max(0, len(chat_history) - 4)
recent_history = chat_history[start_idx:]
# build list of messages
messages = []
for rh in recent_history:
messages.append({"role": rh["role"], "content": rh["content"]})
messages.append({"role": "user", "content": message})
input_ids = torch.tensor([tokenizer.apply_chat_template(
messages, tokenize=True, add_generation_prompt=True)]).cuda()
# trim if needed
if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
yield "[Truncated prior text]\n"
streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
print("detection_list: ", detection_list)
print("steering_list: ", steering_list)
generate_kwargs = {
"base": {"input_ids": input_ids},
"unit_locations": None,
"max_new_tokens": max_new_tokens,
"intervene_on_prompt": True,
"subspaces": [
{
"detect": {
"idx": int(detection_list[0]["idx"]),
"mag": detection_list[0]["internal_mag"]*50,
"subspace_gen_inputs": base_tokenizer(detection_list[0]["subspace_gen_text"], return_tensors="pt").to("cuda") \
if detection_list[0]["subspace_gen_text"] is not None else None
} if detection_list else None,
"steer": {
"idx": int(steering_list[0]["idx"]),
"mag": steering_list[0]["internal_mag"]*50,
"subspace_gen_inputs": base_tokenizer(steering_list[0]["subspace_gen_text"], return_tensors="pt").to("cuda") \
if steering_list[0]["subspace_gen_text"] is not None else None
}
}
] if steering_list else None, # if steering is not provided, we do not steer.
"streamer": streamer,
"do_sample": True
}
t = Thread(target=pv_model.generate, kwargs=generate_kwargs)
t.start()
partial_text = []
for token_str in streamer:
partial_text.append(token_str)
yield "".join(partial_text)
def filter_concepts(search_text: str):
if not search_text.strip():
return concept_list[:500]
filtered = [c for c in concept_list if search_text.lower() in c.lower()]
return filtered[:500]
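# Build the single-entry concept list passed to generate(); "[New] " entries are routed
# through the subspace generator instead of the pretrained dictionary.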
def add_concept_to_list(selected_concept, user_slider_val, current_list):
if not selected_concept:
return current_list
selected_concept_text = None
if selected_concept.startswith("[New] "):
selected_concept_text = selected_concept[len("[New] "):]
idx = 0
else:
idx = concept_id_map[selected_concept]
internal_mag = user_slider_val
new_entry = {
"text": selected_concept,
"idx": idx,
"display_mag": user_slider_val,
"internal_mag": internal_mag,
"subspace_gen_text": selected_concept_text
}
# Only one concept can be active at a time, so replace the list with the new entry
current_list = [new_entry]
return current_list
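# Update the dropdown with matching concepts, or show an alert explaining that a steering
# vector will be generated on the fly when no dictionary concept matches.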
def update_dropdown_choices(search_text, is_detection=False):
filtered = filter_concepts(search_text)
if not filtered:
alert_message = (
"Good news! Based on the topic you provided, we will automatically generate a detector for you!"
) if is_detection else (
"Good news! Based on the topic you provided, we will automatically generate a steering vector. Try it out by starting a chat!"
)
return gr.update(
choices=[],
value=None,
interactive=True
), gr.Textbox(
label="No matching topics found",
value=alert_message,
lines=3,
interactive=False,
visible=True,
elem_id="alert-message"
)
return gr.update(
choices=filtered,
value=filtered[0],
interactive=True,
visible=True
), gr.Textbox(visible=False)
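# Gradio UI: chat interface on the left; topic detection and steering controls on the right.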
with gr.Blocks(css=css, fill_height=True) as demo:
selected_detection = gr.State([])
selected_subspaces = gr.State([])
with gr.Row(min_height=500, equal_height=True):
# Left side: chat area
with gr.Column(scale=7):
gr.Markdown("""# Conditionally Steer AI Responses Based on Topics""")
gr.Markdown("""This is an experimental chatbot that you can steer using topics you care about:
Step 1: Choose a topic (e.g., "Google") to detect
Step 2: Choose a topic (e.g., "ethics") you want the model to discuss when the previous topic comes up
We intervene on Gemma-2-2B-it by adding steering vectors to the residual stream at layer 20.""")
chat_interface = gr.ChatInterface(
fn=generate,
chatbot=gr.Chatbot(),
textbox=gr.Textbox(placeholder="List some search engines with their pros and cons", container=True, scale=7, submit_btn=True),
additional_inputs=[selected_detection, selected_subspaces],
)
# Right side: concept detection and steering
with gr.Column(scale=3):
gr.Markdown("""#### Step 1: Choose a topic the model needs to recognize.""")
with gr.Group():
detect_search = gr.Textbox(
label="Search for topics to detect",
placeholder="Try: 'Google'",
lines=1,
)
detect_msg = gr.TextArea(visible=False)
detect_dropdown = gr.Dropdown(
label="Choose a topic to detect (Click to see more!)",
interactive=True,
allow_custom_value=False,
)
detect_threshold = gr.Slider(
label="Detection sensitivity",
minimum=0,
maximum=1,
step=0.1,
value=0.5,
)
gr.Markdown("---")
gr.Markdown("""#### Step 2: Choose another topic the model needs to discuss when it detects the topic above.""")
with gr.Group():
search_box = gr.Textbox(
label="Search topics to steer",
placeholder="Try: 'ethics'",
lines=1,
)
msg = gr.TextArea(visible=False)
concept_dropdown = gr.Dropdown(
label="Choose a topic to steer the model (Click to see more!)",
interactive=True,
allow_custom_value=False,
)
concept_magnitude = gr.Slider(
label="Steering intensity",
minimum=-5,
maximum=5,
step=0.1,
value=3.5,
)
# Wire up events for detection
detect_search.input(
lambda x: update_dropdown_choices(x, is_detection=True),
[detect_search],
[detect_dropdown, detect_msg]
).then(
add_concept_to_list,
[detect_dropdown, detect_threshold, selected_detection],
[selected_detection]
)
detect_dropdown.select(
add_concept_to_list,
[detect_dropdown, detect_threshold, selected_detection],
[selected_detection]
)
detect_threshold.input(
add_concept_to_list,
[detect_dropdown, detect_threshold, selected_detection],
[selected_detection]
)
# Wire up events for steering
search_box.input(
lambda x: update_dropdown_choices(x, is_detection=False),
[search_box],
[concept_dropdown, msg]
).then(
add_concept_to_list,
[concept_dropdown, concept_magnitude, selected_subspaces],
[selected_subspaces]
)
concept_dropdown.select(
add_concept_to_list,
[concept_dropdown, concept_magnitude, selected_subspaces],
[selected_subspaces]
)
concept_magnitude.input(
add_concept_to_list,
[concept_dropdown, concept_magnitude, selected_subspaces],
[selected_subspaces]
)
demo.launch(share=True, height=1000)