import os, json, random
import torch
import gradio as gr
import spaces
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from huggingface_hub import login, hf_hub_download
import pyreft
import pyvene as pv
from threading import Thread
from typing import Iterator
import torch.nn.functional as F

HF_TOKEN = os.environ.get("HF_TOKEN")
login(token=HF_TOKEN)

MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 128  # smaller default to save memory
MAX_INPUT_TOKEN_LENGTH = 4096

css = """
#alert-message textarea {
    background-color: #e8f4ff;
    border: 1px solid #cce5ff;
    color: #084298;
    font-size: 1.1em;
    padding: 12px;
    border-radius: 4px;
    font-weight: 500;
}
.concept-help {
    font-size: 0.9em;
    color: #666;
    margin-top: 4px;
    font-style: italic;
}
"""


def load_jsonl(jsonl_path):
    jsonl_data = []
    with open(jsonl_path, 'r') as f:
        for line in f:
            data = json.loads(line)
            jsonl_data.append(data)
    return jsonl_data


class Steer(pv.SourcelessIntervention):
    """Steer the model via activation addition."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs, keep_last_dim=True)
        self.proj = torch.nn.Linear(
            self.embed_dim, kwargs["latent_dim"], bias=False)
        self.subspace_generator = kwargs["subspace_generator"]

    def steer(self, base, source=None, subspaces=None):
        if subspaces["steer"]["subspace_gen_inputs"] is not None:
            # We call our subspace generator to generate the steering vector on the fly.
            raw_steering_vec = self.subspace_generator(
                subspaces["steer"]["subspace_gen_inputs"]["input_ids"],
                subspaces["steer"]["subspace_gen_inputs"]["attention_mask"],
            )[0]
            steering_vec = torch.tensor(subspaces["steer"]["mag"]) * \
                raw_steering_vec.unsqueeze(dim=0)
            return base + steering_vec
        else:
            steering_vec = torch.tensor(subspaces["steer"]["mag"]) * \
                self.proj.weight[subspaces["steer"]["idx"]].unsqueeze(dim=0)
            return base + steering_vec

    def forward(self, base, source=None, subspaces=None):
        if subspaces is None:
            return base
        if subspaces["detect"] is not None:
            if subspaces["detect"]["subspace_gen_inputs"] is not None:
                # We call our subspace generator to generate the detection vector on the fly.
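                # The generator returns a unit-normalized concept direction (see
                # RegressionWrapper below); it is reshaped to (hidden_dim, 1) so it
                # can be applied as a projection onto the residual stream.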
                raw_detection_vec = self.subspace_generator(
                    subspaces["detect"]["subspace_gen_inputs"]["input_ids"],
                    subspaces["detect"]["subspace_gen_inputs"]["attention_mask"],
                )[0].unsqueeze(dim=-1)
            else:
                raw_detection_vec = self.proj.weight[subspaces["detect"]["idx"]].unsqueeze(dim=-1)
            print(base.shape)
            print(raw_detection_vec.shape)
            # (batch_size, seq, 1) -> (batch_size, seq)
            detection_latent = torch.matmul(base, raw_detection_vec.to(base.dtype)).squeeze(dim=-1)
            # (batch_size, seq) -> scalar for the single example in the batch
            max_latent = torch.max(detection_latent, dim=-1).values[0]
            print("max_latent", max_latent)
            if max_latent > torch.tensor(subspaces["detect"]["mag"]):
                print("Detected!")
                return self.steer(base, source, subspaces)
            else:
                return base
        else:
            return self.steer(base, source, subspaces)


class RegressionWrapper(torch.nn.Module):
    def __init__(self, base_model, hidden_size, output_dim):
        super().__init__()
        self.base_model = base_model
        self.regression_head = torch.nn.Linear(hidden_size, output_dim)

    def forward(self, input_ids, attention_mask):
        outputs = self.base_model.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            return_dict=True
        )
        last_hiddens = outputs.hidden_states[-1]
        last_token_representations = last_hiddens[:, -1]
        preds = self.regression_head(last_token_representations)
        preds = F.normalize(preds, p=2, dim=-1)
        return preds


# Check GPU.
if not torch.cuda.is_available():
    print("Warning: Running on CPU, may be slow.")

# Load model & concept dictionary.
model_id = "google/gemma-2-2b-it"
pv_model = None
tokenizer = None
concept_list = []
concept_id_map = {}

if torch.cuda.is_available():
    model = AutoModelForCausalLM.from_pretrained(
        model_id, device_map="cuda", torch_dtype=torch.bfloat16
    )
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    # Download the concept dictionary and its metadata.
    weight_path = hf_hub_download(repo_id="pyvene/gemma-reft-2b-it-res", filename="l20/weight.pt")
    meta_path = hf_hub_download(repo_id="pyvene/gemma-reft-2b-it-res", filename="l20/metadata.jsonl")
    params = torch.load(weight_path).cuda()
    md = load_jsonl(meta_path)
    concept_list = [item["concept"] for item in md]
    concept_id_map = {}
    # Re-index the concepts because one concept is missing from the dictionary.
    concept_reindex = 0
    for item in md:
        concept_id_map[item["concept"]] = concept_reindex
        concept_reindex += 1

    # Load the subspace generator.
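    # The subspace generator is a gemma-2-2b backbone with a linear regression head
    # (RegressionWrapper above). Given a free-text topic, it produces a unit-norm
    # direction for the layer-20 residual stream, used for detection and steering
    # whenever the topic is not in the pre-computed concept dictionary.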
    base_tokenizer = AutoTokenizer.from_pretrained(
        "google/gemma-2-2b", model_max_length=512)
    config = AutoConfig.from_pretrained("google/gemma-2-2b")
    base_model = AutoModelForCausalLM.from_config(config)
    subspace_generator_weight_path = hf_hub_download(
        repo_id="pyvene/gemma-reft-2b-it-res-generator", filename="l20/weight.pt")
    hidden_size = base_model.config.hidden_size
    subspace_generator = RegressionWrapper(
        base_model, hidden_size, hidden_size).bfloat16().to("cuda")
    subspace_generator.load_state_dict(torch.load(subspace_generator_weight_path))
    print(f"Loaded subspace generator weights from {subspace_generator_weight_path}")
    _ = subspace_generator.eval()

    steer = Steer(
        embed_dim=params.shape[0], latent_dim=params.shape[1],
        subspace_generator=subspace_generator)
    steer.proj.weight.data = params.float()

    # Attach the intervention to the residual stream at layer 20.
    pv_model = pv.IntervenableModel({
        "component": "model.layers[20].output",
        "intervention": steer}, model=model)

terminators = [tokenizer.eos_token_id] if tokenizer else []


@spaces.GPU
def generate(
    message: str,
    chat_history: list[dict],
    detection_list: list[dict],
    steering_list: list[dict],
    max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
) -> Iterator[str]:
    # Limit the context to the last 4 turns.
    start_idx = max(0, len(chat_history) - 4)
    recent_history = chat_history[start_idx:]

    # Build the list of messages.
    messages = []
    for rh in recent_history:
        messages.append({"role": rh["role"], "content": rh["content"]})
    messages.append({"role": "user", "content": message})

    input_ids = torch.tensor([tokenizer.apply_chat_template(
        messages, tokenize=True, add_generation_prompt=True)]).cuda()

    # Trim the prompt if needed.
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        yield "[Truncated prior text]\n"

    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)

    print("detection_list: ", detection_list)
    print("steering_list: ", steering_list)
    generate_kwargs = {
        "base": {"input_ids": input_ids},
        "unit_locations": None,
        "max_new_tokens": max_new_tokens,
        "intervene_on_prompt": True,
        "subspaces": [
            {
                "detect": {
                    "idx": int(detection_list[0]["idx"]),
                    "mag": detection_list[0]["internal_mag"] * 50,
                    "subspace_gen_inputs": base_tokenizer(detection_list[0]["subspace_gen_text"], return_tensors="pt").to("cuda")
                    if detection_list[0]["subspace_gen_text"] is not None else None
                } if detection_list else None,
                "steer": {
                    "idx": int(steering_list[0]["idx"]),
                    "mag": steering_list[0]["internal_mag"] * 50,
                    "subspace_gen_inputs": base_tokenizer(steering_list[0]["subspace_gen_text"], return_tensors="pt").to("cuda")
                    if steering_list[0]["subspace_gen_text"] is not None else None
                }
            }
        ] if steering_list else None,  # If steering is not provided, we do not steer.
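        # pyvene passes this "subspaces" payload through to Steer.forward on the
        # intervened layer; the slider values are rescaled (x50) above to the
        # intervention's internal magnitude scale.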
"streamer": streamer, "do_sample": True } t = Thread(target=pv_model.generate, kwargs=generate_kwargs) t.start() partial_text = [] for token_str in streamer: partial_text.append(token_str) yield "".join(partial_text) def filter_concepts(search_text: str): if not search_text.strip(): return concept_list[:500] filtered = [c for c in concept_list if search_text.lower() in c.lower()] return filtered[:500] def add_concept_to_list(selected_concept, user_slider_val, current_list): if not selected_concept: return current_list selected_concept_text = None if selected_concept.startswith("[New] "): selected_concept_text = selected_concept[6:] idx = 0 else: idx = concept_id_map[selected_concept] internal_mag = user_slider_val new_entry = { "text": selected_concept, "idx": idx, "display_mag": user_slider_val, "internal_mag": internal_mag, "subspace_gen_text": selected_concept_text } # Add to the beginning of the list current_list = [new_entry] return current_list def update_dropdown_choices(search_text, is_detection=False): filtered = filter_concepts(search_text) if not filtered or len(filtered) == 0: alert_message = ( "Good news! Based on the topic you provided, we will automatically generate a detector for you!" ) if is_detection else ( "Good news! Based on the topic you provided, we will automatically generate a steering vector. Try it out by starting a chat!" ) return gr.update( choices=[], value=None, interactive=True ), gr.Textbox( label="No matching topics found", value=alert_message, lines=3, interactive=False, visible=True, elem_id="alert-message" ) return gr.update( choices=filtered, value=filtered[0], interactive=True, visible=True ), gr.Textbox(visible=False) with gr.Blocks(css=css, fill_height=True) as demo: selected_detection = gr.State([]) selected_subspaces = gr.State([]) with gr.Row(min_height=500, equal_height=True): # Left side: chat area with gr.Column(scale=7): gr.Markdown("""# Conditionally Steer AI Responses Based on Topics""") gr.Markdown("""This is an experimental chatbot that you can steer using topics you care about: Step 1: Choose a topic (e.g., "Google") to detect Step 2: Choose a topic (e.g., "ethics") you want the model to discuss when the previous topic comes up We intervene on Gemma-2-2B-it by adding steering vectors to the residual stream at layer 20.""") chat_interface = gr.ChatInterface( fn=generate, chatbot=gr.Chatbot(), textbox=gr.Textbox(placeholder="List some search engines with their pros and cons", container=True, scale=7, submit_btn=True), additional_inputs=[selected_detection, selected_subspaces], ) # Right side: concept detection and steering with gr.Column(scale=3): gr.Markdown("""#### Step 1: Choose a topic the model needs to recognize.""") with gr.Group(): detect_search = gr.Textbox( label="Search for topics to detect", placeholder="Try: 'Google'", lines=1, ) detect_msg = gr.TextArea(visible=False) detect_dropdown = gr.Dropdown( label="Choose a topic to detect (Click to see more!)", interactive=True, allow_custom_value=False, ) detect_threshold = gr.Slider( label="Detection sensitivity", minimum=0, maximum=1, step=0.1, value=0.5, ) gr.Markdown("---") gr.Markdown("""#### Step 2: Choose another topic the model needs to discuss when it detects the topic above.""") with gr.Group(): search_box = gr.Textbox( label="Search topics to steer", placeholder="Try: 'ethics'", lines=1, ) msg = gr.TextArea(visible=False) concept_dropdown = gr.Dropdown( label="Choose a topic to steer the model (Click to see more!)", interactive=True, allow_custom_value=False, ) 
                concept_magnitude = gr.Slider(
                    label="Steering intensity",
                    minimum=-5,
                    maximum=5,
                    step=0.1,
                    value=3.5,
                )

    # Wire up events for detection.
    detect_search.input(
        lambda x: update_dropdown_choices(x, is_detection=True),
        [detect_search],
        [detect_dropdown, detect_msg]
    ).then(
        add_concept_to_list,
        [detect_dropdown, detect_threshold, selected_detection],
        [selected_detection]
    )
    detect_dropdown.select(
        add_concept_to_list,
        [detect_dropdown, detect_threshold, selected_detection],
        [selected_detection]
    )
    detect_threshold.input(
        add_concept_to_list,
        [detect_dropdown, detect_threshold, selected_detection],
        [selected_detection]
    )

    # Wire up events for steering.
    search_box.input(
        lambda x: update_dropdown_choices(x, is_detection=False),
        [search_box],
        [concept_dropdown, msg]
    ).then(
        add_concept_to_list,
        [concept_dropdown, concept_magnitude, selected_subspaces],
        [selected_subspaces]
    )
    concept_dropdown.select(
        add_concept_to_list,
        [concept_dropdown, concept_magnitude, selected_subspaces],
        [selected_subspaces]
    )
    concept_magnitude.input(
        add_concept_to_list,
        [concept_dropdown, concept_magnitude, selected_subspaces],
        [selected_subspaces]
    )

demo.launch(share=True, height=1000)